From 9706ef4ab8beacd16bcdaa29ef11179982b796a6 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 21 May 2007 00:22:30 +0000 Subject: [PATCH] - redefined how the mapper determines primary keys. this is to help with the new deferred polymorphic loading. it takes stock of all the primary keys of all of its tables in all cases, including when a custom primary key is sent, to maximize its chances of being able to INSERT into each table. then, whether or not the custom primary key is sent, it gathers together columns which are equivalent via a foreign key relationship to each other or via a common parent column, similarly to how Join does it. this continues along the path first set up from [ticket:185]. so primary keys of mappers are always going to be "minimized" as far as number of columns. finally, the list of pk cols is normalized to the mapped table. this becomes the mapper's "primary key" and is distinct from all the per-table pk column collections. - added "deferred poly load" versions to magazine test, cut down on table recreates in polymorph test. --- lib/sqlalchemy/orm/mapper.py | 81 +++++++----- lib/sqlalchemy/orm/query.py | 2 +- lib/sqlalchemy/orm/util.py | 2 - lib/sqlalchemy/sql.py | 1 + test/orm/inheritance/magazine.py | 64 +++++---- test/orm/inheritance/manytomany.py | 3 +- test/orm/inheritance/polymorph.py | 201 +++++++++++++++-------------- test/orm/inheritance/polymorph2.py | 7 + 8 files changed, 200 insertions(+), 161 deletions(-) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 91e9999443..6ae4fd647e 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -484,39 +484,53 @@ class Mapper(object): # may be a join or other construct self.tables = sqlutil.TableFinder(self.mapped_table) - # determine primary key columns, either passed in, or get them from our set of tables + # determine primary key columns self.pks_by_table = {} + + # go through all of our represented tables + # and assemble primary key columns + for t in self.tables + [self.mapped_table]: + try: + l = self.pks_by_table[t] + except KeyError: + l = self.pks_by_table.setdefault(t, util.OrderedSet()) + for k in t.primary_key: + l.add(k) + if self.primary_key_argument is not None: - # determine primary keys using user-given list of primary key columns as a guide - # - # TODO: this might not work very well for joined-table and/or polymorphic - # inheritance mappers since local_table isnt taken into account nor is select_table - # need to test custom primary key columns used with inheriting mappers for k in self.primary_key_argument: self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k) - if k.table != self.mapped_table: - # associate pk cols from subtables to the "main" table - corr = self.mapped_table.corresponding_column(k, raiseerr=False) - if corr is not None: - self.pks_by_table.setdefault(self.mapped_table, util.OrderedSet()).add(corr) - else: - # no user-defined primary key columns - go through all of our represented tables - # and assemble primary key columns - for t in self.tables + [self.mapped_table]: - try: - l = self.pks_by_table[t] - except KeyError: - l = self.pks_by_table.setdefault(t, util.OrderedSet()) - for k in t.primary_key: - #if k.key not in t.c and k._label not in t.c: - # this is a condition that was occurring when table reflection was doubling up primary keys - # that were overridden in the Table constructor - # raise exceptions.AssertionError("Column " + str(k) + " not located in the column set of table " + str(t)) - l.add(k) - + if len(self.pks_by_table[self.mapped_table]) == 0: raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) - self.primary_key = self.pks_by_table[self.mapped_table] + + # create the "primary_key" for this mapper. this will flatten "equivalent" primary key columns + # into one column, where "equivalent" means that one column references the other via foreign key, or + # multiple columns that all reference a common parent column. it will also resolve the column + # against the "mapped_table" of this mapper. + primary_key = sql.ColumnCollection() + equivs = {} + for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): + if not len(col.foreign_keys): + equivs.setdefault(col, util.Set()).add(col) + else: + for fk in col.foreign_keys: + equivs.setdefault(fk.column, util.Set()).add(col) + for col in equivs: + c = self.mapped_table.corresponding_column(col, raiseerr=False) + if c is None: + for cc in equivs[col]: + c = self.mapped_table.corresponding_column(cc, raiseerr=False) + if c is not None: + break + else: + raise exceptions.ArgumentError("Cant resolve column " + str(col)) + primary_key.add(c) + + if len(primary_key) == 0: + raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) + + self.primary_key = primary_key def _compile_properties(self): """Inspect the properties dictionary sent to the Mapper's @@ -576,8 +590,7 @@ class Mapper(object): # its a ColumnProperty - match the ultimate table columns # back to the property - proplist = self.columntoproperty.setdefault(column, []) - proplist.append(prop) + self.columntoproperty.setdefault(column, []).append(prop) def _initialize_properties(self): @@ -921,7 +934,7 @@ class Mapper(object): dictionary corresponding result-set ``ColumnElement`` instances to their values within a row. """ - return (self.class_, tuple([row[column] for column in self.pks_by_table[self.mapped_table]]), self.entity_name) + return (self.class_, tuple([row[column] for column in self.primary_key]), self.entity_name) def identity_key_from_primary_key(self, primary_key): """Return an identity-map key for use in storing/retrieving an @@ -946,7 +959,7 @@ class Mapper(object): instance. """ - return [self.get_attr_by_column(instance, column) for column in self.pks_by_table[self.mapped_table]] + return [self.get_attr_by_column(instance, column) for column in self.primary_key] def canload(self, instance): """return true if this mapper is capable of loading the given instance""" @@ -1434,14 +1447,14 @@ class Mapper(object): return instance else: if self.__should_log_debug: - self.__log_debug("_instance(): identity key %s not in session" % str(identitykey) + repr([mapperutil.instance_str(x) for x in context.session])) + self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) # look in result-local identitymap for it. exists = context.identity_map.has_key(identitykey) if not exists: if self.allow_null_pks: # check if *all* primary key cols in the result are None - this indicates # an instance of the object is not present in the row. - for col in self.pks_by_table[self.mapped_table]: + for col in self.primary_key: if row[col] is not None: break else: @@ -1449,7 +1462,7 @@ class Mapper(object): else: # otherwise, check if *any* primary key cols in the result are None - this indicates # an instance of the object is not present in the row. - for col in self.pks_by_table[self.mapped_table]: + for col in self.primary_key: if row[col] is None: return None diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a87fa8d19a..6ed1a06d35 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -84,7 +84,7 @@ class Query(object): return self._session table = property(lambda s:s.select_mapper.mapped_table) - primary_key_columns = property(lambda s:s.select_mapper.pks_by_table[s.select_mapper.mapped_table]) + primary_key_columns = property(lambda s:s.select_mapper.primary_key) session = property(_get_session) def get(self, ident, **kwargs): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 923dd67977..e80d954bd9 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -89,8 +89,6 @@ class TranslatingDict(dict): def __translate_col(self, col): ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False) - #if col is not ourcol: - # print "TD TRANSLATING ", col, "TO", ourcol if ourcol is None: return col else: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 9a93474dc7..ec22179efa 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -2599,6 +2599,7 @@ class TableClause(FromClause): return [c for c in self.c] else: return [] + def accept_visitor(self, visitor): visitor.visit_table(self) diff --git a/test/orm/inheritance/magazine.py b/test/orm/inheritance/magazine.py index a9c88ef60c..9a52919dd7 100644 --- a/test/orm/inheritance/magazine.py +++ b/test/orm/inheritance/magazine.py @@ -64,8 +64,8 @@ class MagazinePage(Page): class ClassifiedPage(MagazinePage): pass -class InheritTest(testbase.ORMTest): - """tests a large polymorphic relationship""" + +class MagazineTest(testbase.ORMTest): def define_tables(self, metadata): global publication_table, issue_table, location_table, location_name_table, magazine_table, \ page_table, magazine_page_table, classified_page_table, page_size_table @@ -116,6 +116,8 @@ class InheritTest(testbase.ORMTest): Column('name', String(45), default=''), ) +def generate_round_trip_test(use_unions=False): + def test_roundtrip(self): publication_mapper = mapper(Publication, publication_table) issue_mapper = mapper(Issue, issue_table, properties = { @@ -133,33 +135,42 @@ class InheritTest(testbase.ORMTest): page_size_mapper = mapper(PageSize, page_size_table) - page_join = polymorphic_union( - { - 'm': page_table.join(magazine_page_table), - 'c': page_table.join(magazine_page_table).join(classified_page_table), - 'p': page_table.select(page_table.c.type=='p'), - }, None, 'page_join') - - magazine_join = polymorphic_union( - { - 'm': page_table.join(magazine_page_table), - 'c': page_table.join(magazine_page_table).join(classified_page_table), - }, None, 'page_join') - magazine_mapper = mapper(Magazine, magazine_table, properties = { 'location': relation(Location, backref=backref('magazine', uselist=False)), 'size': relation(PageSize), }) - page_mapper = mapper(Page, page_table, select_table=page_join, polymorphic_on=page_join.c.type, polymorphic_identity='p') + if use_unions: + page_join = polymorphic_union( + { + 'm': page_table.join(magazine_page_table), + 'c': page_table.join(magazine_page_table).join(classified_page_table), + 'p': page_table.select(page_table.c.type=='p'), + }, None, 'page_join') + page_mapper = mapper(Page, page_table, select_table=page_join, polymorphic_on=page_join.c.type, polymorphic_identity='p') + else: + page_mapper = mapper(Page, page_table, polymorphic_on=page_table.c.type, polymorphic_identity='p') + + if use_unions: + magazine_join = polymorphic_union( + { + 'm': page_table.join(magazine_page_table), + 'c': page_table.join(magazine_page_table).join(classified_page_table), + }, None, 'page_join') + magazine_page_mapper = mapper(MagazinePage, magazine_page_table, select_table=magazine_join, inherits=page_mapper, polymorphic_identity='m', properties={ + 'magazine': relation(Magazine, backref=backref('pages', order_by=magazine_join.c.page_no)) + }) + else: + magazine_page_mapper = mapper(MagazinePage, magazine_page_table, inherits=page_mapper, polymorphic_identity='m', properties={ + 'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no)) + }) + + classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c', primary_key=[page_table.c.id]) + compile_mappers() + print [str(s) for s in classified_page_mapper.primary_key] + print classified_page_mapper.columntoproperty[page_table.c.id] - magazine_page_mapper = mapper(MagazinePage, magazine_page_table, select_table=magazine_join, inherits=page_mapper, polymorphic_identity='m', properties={ - 'magazine': relation(Magazine, backref=backref('pages', order_by=magazine_join.c.page_no)) - }) - classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c') - - def testone(self): session = create_session() pub = Publication(name='Test') @@ -174,7 +185,7 @@ class InheritTest(testbase.ORMTest): page2 = MagazinePage(magazine=magazine,page_no=2) page3 = ClassifiedPage(magazine=magazine,page_no=3) session.save(pub) - + session.flush() print [x for x in session] session.clear() @@ -186,6 +197,13 @@ class InheritTest(testbase.ORMTest): print p.issues[0].locations[0].magazine.pages print [page, page2, page3] assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3]) + + test_roundtrip.__name__ = "test_%s" % (not use_union and "Nounion" or "Unions") + setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip) + +for use_union in [True, False]: + generate_round_trip_test(use_union) + if __name__ == '__main__': testbase.main() diff --git a/test/orm/inheritance/manytomany.py b/test/orm/inheritance/manytomany.py index 380ccf0041..f97b8ed0d5 100644 --- a/test/orm/inheritance/manytomany.py +++ b/test/orm/inheritance/manytomany.py @@ -100,7 +100,8 @@ class InheritTest2(testbase.ORMTest): mapper(Foo, foo) mapper(Bar, bar, inherits=Foo) - + print foo.join(bar).primary_key + print class_mapper(Bar).primary_key b = Bar('somedata') sess = create_session() sess.save(b) diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/polymorph.py index 107203e472..0ef9984ae5 100644 --- a/test/orm/inheritance/polymorph.py +++ b/test/orm/inheritance/polymorph.py @@ -191,7 +191,10 @@ class RelationToSubclassTest(PolymorphTest): sess.query(Company).get_by(company_id=c.company_id) assert sets.Set([e.get_name() for e in c.managers]) == sets.Set(['pointy haired boss']) assert c.managers[0].company is c - + +class RoundTripTest(PolymorphTest): + pass + def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, use_union=False): """generates a round trip test. @@ -200,119 +203,117 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co redefine_colprop - if we redefine the 'name' column to be 'people_name' on the base Person class use_literal_join - primary join condition is explicitly specified """ - class RoundTripTest(PolymorphTest): - def test_roundtrip(self): - # create a union that represents both types of joins. - if not use_union: - person_join = None - elif include_base: - person_join = polymorphic_union( - { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - 'person':people.select(people.c.type=='person'), - }, None, 'pjoin') - else: - person_join = polymorphic_union( - { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - }, None, 'pjoin') - - if redefine_colprop: - person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name}) - else: - person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=people.c.type, polymorphic_identity='person') - - mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') - - if use_literal_join: - mapper(Company, companies, properties={ - 'employees': relation(Person, lazy=lazy_relation, primaryjoin=people.c.company_id==companies.c.company_id, private=True, - backref="company" - ) - }) - else: - mapper(Company, companies, properties={ - 'employees': relation(Person, lazy=lazy_relation, private=True, - backref="company" - ) - }) - - if redefine_colprop: - person_attribute_name = 'person_name' - else: - person_attribute_name = 'name' + def test_roundtrip(self): + # create a union that represents both types of joins. + if not use_union: + person_join = None + elif include_base: + person_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + 'person':people.select(people.c.type=='person'), + }, None, 'pjoin') + else: + person_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + }, None, 'pjoin') + + if redefine_colprop: + person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name}) + else: + person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=people.c.type, polymorphic_identity='person') - session = create_session() - c = Company(name='company1') - c.employees.append(Manager(status='AAB', manager_name='manager1', **{person_attribute_name:'pointy haired boss'})) - c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', **{person_attribute_name:'dilbert'})) - if include_base: - c.employees.append(Person(status='HHH', **{person_attribute_name:'joesmith'})) - c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'})) - c.employees.append(Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'})) - session.save(c) - print session.new - session.flush() - session.clear() - id = c.company_id - c = session.query(Company).get(id) - for e in c.employees: - print e, e._instance_key, e.company - if include_base: - assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), ('wally', 'CGG'), ('jsmith', 'ABA')]) - else: - assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')]) - print "\n" + mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') + mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') + if use_literal_join: + mapper(Company, companies, properties={ + 'employees': relation(Person, lazy=lazy_relation, primaryjoin=people.c.company_id==companies.c.company_id, private=True, + backref="company" + ) + }) + else: + mapper(Company, companies, properties={ + 'employees': relation(Person, lazy=lazy_relation, private=True, + backref="company" + ) + }) - # test selecting from the query, using the base mapped table (people) as the selection criterion. - # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" - dilbert = session.query(Person).selectfirst(people.c.name=='dilbert') - dilbert2 = session.query(Engineer).selectfirst(people.c.name=='dilbert') - assert dilbert is dilbert2 - - # test selecting from the query, joining against an alias of the base "people" table. test that - # the "palias" alias does *not* get sucked up into the "person_join" conversion. - palias = people.alias("palias") - session.query(Person).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id)) - dilbert2 = session.query(Engineer).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id)) - assert dilbert is dilbert2 - - session.query(Person).selectfirst((engineers.c.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id)) - dilbert2 = session.query(Engineer).selectfirst(engineers.c.engineer_name=="engineer1") - assert dilbert is dilbert2 - - - dilbert.engineer_name = 'hes dibert!' + if redefine_colprop: + person_attribute_name = 'person_name' + else: + person_attribute_name = 'name' + + session = create_session() + c = Company(name='company1') + c.employees.append(Manager(status='AAB', manager_name='manager1', **{person_attribute_name:'pointy haired boss'})) + c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', **{person_attribute_name:'dilbert'})) + if include_base: + c.employees.append(Person(status='HHH', **{person_attribute_name:'joesmith'})) + c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'})) + c.employees.append(Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'})) + session.save(c) + print session.new + session.flush() + session.clear() + id = c.company_id + c = session.query(Company).get(id) + for e in c.employees: + print e, e._instance_key, e.company + if include_base: + assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), ('wally', 'CGG'), ('jsmith', 'ABA')]) + else: + assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')]) + print "\n" + + + # test selecting from the query, using the base mapped table (people) as the selection criterion. + # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" + dilbert = session.query(Person).selectfirst(people.c.name=='dilbert') + dilbert2 = session.query(Engineer).selectfirst(people.c.name=='dilbert') + assert dilbert is dilbert2 + + # test selecting from the query, joining against an alias of the base "people" table. test that + # the "palias" alias does *not* get sucked up into the "person_join" conversion. + palias = people.alias("palias") + session.query(Person).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id)) + dilbert2 = session.query(Engineer).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id)) + assert dilbert is dilbert2 + + session.query(Person).selectfirst((engineers.c.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id)) + dilbert2 = session.query(Engineer).selectfirst(engineers.c.engineer_name=="engineer1") + assert dilbert is dilbert2 + + + dilbert.engineer_name = 'hes dibert!' - session.flush() - session.clear() + session.flush() + session.clear() - c = session.query(Company).get(id) - for e in c.employees: - print e, e._instance_key + c = session.query(Company).get(id) + for e in c.employees: + print e, e._instance_key - session.delete(c) - session.flush() + session.delete(c) + session.flush() - RoundTripTest.__name__ = "Test%s%s%s%s" % ( - (lazy_relation and "Lazy" or "Eager"), - (include_base and "Inclbase" or ""), - (redefine_colprop and "Redefcol" or ""), - (not use_union and "Nounion" or (use_literal_join and "Litjoin" or "")) + test_roundtrip.__name__ = "test_%s%s%s%s" % ( + (lazy_relation and "lazy" or "eager"), + (include_base and "_inclbase" or ""), + (redefine_colprop and "_redefcol" or ""), + (not use_union and "_nounion" or (use_literal_join and "_litjoin" or "")) ) - return RoundTripTest + setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip) for include_base in [True, False]: for lazy_relation in [True, False]: for redefine_colprop in [True, False]: for use_literal_join in [True, False]: for use_union in [True, False]: - testclass = generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, use_union) - exec("%s = testclass" % testclass.__name__) + generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, use_union) if __name__ == "__main__": testbase.main() diff --git a/test/orm/inheritance/polymorph2.py b/test/orm/inheritance/polymorph2.py index 1c42146356..fb704a495c 100644 --- a/test/orm/inheritance/polymorph2.py +++ b/test/orm/inheritance/polymorph2.py @@ -329,6 +329,9 @@ class RelationTest4(testbase.ORMTest): manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') car_mapper = mapper(Car, cars, properties= {'employee':relation(person_mapper)}) + print class_mapper(Person).primary_key + print person_mapper.get_select_mapper().primary_key + # so the primaryjoin is "people.c.person_id==cars.c.owner". the "lazy" clause will be # "people.c.person_id=?". the employee_join is two selects union'ed together, one of which # will contain employee.c.person_id the other contains manager.c.person_id. people.c.person_id is not explicitly in @@ -361,9 +364,13 @@ class RelationTest4(testbase.ORMTest): session.clear() + print "----------------------------" car1 = session.query(Car).get(car1.car_id) + print "----------------------------" usingGet = session.query(person_mapper).get(car1.owner) + print "----------------------------" usingProperty = car1.employee + print "----------------------------" # All print should output the same person (engineer E4) assert str(engineer4) == "Engineer E4, status X" -- 2.47.3