From: Mike Bayer Date: Mon, 23 Jul 2007 22:20:44 +0000 (+0000) Subject: - joined-table inheritance will now generate the primary key X-Git-Tag: rel_0_4_6~38 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5130d25fd144f7434a49a5fa7000b468bbd939c2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - joined-table inheritance will now generate the primary key columns of all inherited classes against the root table of the join only. This implies that each row in the root table is distinct to a single instance. If for some rare reason this is not desireable, explicit primary_key settings on individual mappers will override it. - When "polymorphic" flags are used with joined-table or single-table inheritance, all identity keys are generated against the root class of the inheritance hierarchy; this allows query.get() to work polymorphically using the same caching semantics as a non-polymorphic get. note that this currently does not work with concrete inheritance. --- diff --git a/CHANGES b/CHANGES index 4834e36626..092ee176ec 100644 --- a/CHANGES +++ b/CHANGES @@ -104,6 +104,18 @@ should "collapse" into a single-valued (or fewer-valued) primary key. fixes things like [ticket:611]. + - joined-table inheritance will now generate the primary key + columns of all inherited classes against the root table of the + join only. This implies that each row in the root table is distinct + to a single instance. If for some rare reason this is not desireable, + explicit primary_key settings on individual mappers will override it. + + - When "polymorphic" flags are used with joined-table or single-table + inheritance, all identity keys are generated against the root class + of the inheritance hierarchy; this allows query.get() to work + polymorphically using the same caching semantics as a non-polymorphic get. + note that this currently does not work with concrete inheritance. + - secondary inheritance loading: polymorphic mappers can be constructed *without* a select_table argument. inheriting mappers whose tables were not represented in the initial load will issue a diff --git a/examples/polymorph/concrete.py b/examples/polymorph/concrete.py index 593d3f4805..5f12e9a3d7 100644 --- a/examples/polymorph/concrete.py +++ b/examples/polymorph/concrete.py @@ -1,4 +1,5 @@ from sqlalchemy import * +from sqlalchemy.orm import * metadata = MetaData() @@ -49,7 +50,7 @@ manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, concr engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer') -session = create_session(bind_to=engine) +session = create_session(bind=engine) m1 = Manager("pointy haired boss", "manager1") e1 = Engineer("wally", "engineer1") diff --git a/examples/polymorph/single.py b/examples/polymorph/single.py index dcdb3c8906..61809a05c1 100644 --- a/examples/polymorph/single.py +++ b/examples/polymorph/single.py @@ -1,4 +1,5 @@ from sqlalchemy import * +from sqlalchemy.orm import * metadata = MetaData('sqlite://', echo='debug') diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 097c906abf..555c1990e5 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -469,8 +469,17 @@ class Mapper(object): self.mapped_table = self.local_table if self.polymorphic_identity is not None: self.inherits._add_polymorphic_mapping(self.polymorphic_identity, self) - if self.polymorphic_on is None and self.inherits.polymorphic_on is not None: - self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False) + if self.polymorphic_on is None: + if self.inherits.polymorphic_on is not None: + self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False) + else: + raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) + + if self.polymorphic_identity is not None and not self.concrete: + self._identity_class = self.inherits._identity_class + else: + self._identity_class = self.class_ + if self.order_by is False: self.order_by = self.inherits.order_by self.polymorphic_map = self.inherits.polymorphic_map @@ -480,8 +489,11 @@ class Mapper(object): self._synchronizer = None self.mapped_table = self.local_table if self.polymorphic_identity is not None: + if self.polymorphic_on is None: + raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) self._add_polymorphic_mapping(self.polymorphic_identity, self) - + self._identity_class = self.class_ + if self.mapped_table is None: raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self)) @@ -540,58 +552,60 @@ class Mapper(object): if len(self.pks_by_table[self.mapped_table]) == 0: raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) - # create the "primary_key" for this mapper. this will flatten "equivalent" primary key columns - # into one column, where "equivalent" means that one column references the other via foreign key, or - # multiple columns that all reference a common parent column. it will also resolve the column - # against the "mapped_table" of this mapper. - equivalent_columns = self._get_equivalent_columns() + if self.inherits is not None and not self.concrete and not self.primary_key_argument: + self.primary_key = self.inherits.primary_key + self._get_clause = self.inherits._get_clause + else: + # create the "primary_key" for this mapper. this will flatten "equivalent" primary key columns + # into one column, where "equivalent" means that one column references the other via foreign key, or + # multiple columns that all reference a common parent column. it will also resolve the column + # against the "mapped_table" of this mapper. + equivalent_columns = self._get_equivalent_columns() - primary_key = sql.ColumnSet() - - for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): - #primary_key.add(col) - #continue - c = self.mapped_table.corresponding_column(col, raiseerr=False) - if c is None: - for cc in equivalent_columns[col]: - c = self.mapped_table.corresponding_column(cc, raiseerr=False) - if c is not None: + primary_key = sql.ColumnSet() + + for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): + c = self.mapped_table.corresponding_column(col, raiseerr=False) + if c is None: + for cc in equivalent_columns[col]: + c = self.mapped_table.corresponding_column(cc, raiseerr=False) + if c is not None: + break + else: + raise exceptions.ArgumentError("Cant resolve column " + str(col)) + + # this step attempts to resolve the column to an equivalent which is not + # a foreign key elsewhere. this helps with joined table inheritance + # so that PKs are expressed in terms of the base table which is always + # present in the initial select + # TODO: this is a little hacky right now, the "tried" list is to prevent + # endless loops between cyclical FKs, try to make this cleaner/work better/etc., + # perhaps via topological sort (pick the leftmost item) + tried = util.Set() + while True: + if not len(c.foreign_keys) or c in tried: break - else: - raise exceptions.ArgumentError("Cant resolve column " + str(col)) - - # this step attempts to resolve the column to an equivalent which is not - # a foreign key elsewhere. this helps with joined table inheritance - # so that PKs are expressed in terms of the base table which is always - # present in the initial select - # TODO: this is a little hacky right now, the "tried" list is to prevent - # endless loops between cyclical FKs, try to make this cleaner/work better/etc., - # perhaps via topological sort (pick the leftmost item) - tried = util.Set() - while True: - if not len(c.foreign_keys) or c in tried: - break - for cc in c.foreign_keys: - cc = cc.column - c2 = self.mapped_table.corresponding_column(cc, raiseerr=False) - if c2 is not None: - c = c2 - tried.add(c) + for cc in c.foreign_keys: + cc = cc.column + c2 = self.mapped_table.corresponding_column(cc, raiseerr=False) + if c2 is not None: + c = c2 + tried.add(c) + break + else: break - else: - break - primary_key.add(c) + primary_key.add(c) - if len(primary_key) == 0: - raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) + if len(primary_key) == 0: + raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) - self.primary_key = primary_key - self.__log("Identified primary key columns: " + str(primary_key)) + self.primary_key = primary_key + self.__log("Identified primary key columns: " + str(primary_key)) - _get_clause = sql.and_() - for primary_key in self.primary_key: - _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True)) - self._get_clause = _get_clause + _get_clause = sql.and_() + for primary_key in self.primary_key: + _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True)) + self._get_clause = _get_clause def _get_equivalent_columns(self): """Create a map of all *equivalent* columns, based on @@ -996,7 +1010,7 @@ class Mapper(object): dictionary corresponding result-set ``ColumnElement`` instances to their values within a row. """ - return (self.class_, tuple([row[column] for column in self.primary_key]), self.entity_name) + return (self._identity_class, tuple([row[column] for column in self.primary_key]), self.entity_name) def identity_key_from_primary_key(self, primary_key): """Return an identity-map key for use in storing/retrieving an @@ -1005,7 +1019,7 @@ class Mapper(object): primary_key A list of values indicating the identifier. """ - return (self.class_, tuple(util.to_list(primary_key)), self.entity_name) + return (self._identity_class, tuple(util.to_list(primary_key)), self.entity_name) def identity_key_from_instance(self, instance): """Return the identity key for the given instance, based on diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index f31d0013c2..d91fbe4b52 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -147,7 +147,6 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): def visit_clauselist(self, clist): for i in range(0, len(clist.clauses)): n = self.convert_element(clist.clauses[i]) - print "CONVERTEING CLAUSELIST W ID", id(clist) if n is not None: clist.clauses[i] = n diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index c6cd43f439..be623e1b87 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -62,9 +62,93 @@ class O2MTest(ORMTest): self.assert_(compare == result) self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') -class AddPropTest(ORMTest): - """testing that construction of inheriting mappers works regardless of when extra properties - are added to the superclass mapper""" +class GetTest(ORMTest): + def define_tables(self, metadata): + global foo, bar, blub + foo = Table('foo', metadata, + Column('id', Integer, Sequence('foo_seq'), primary_key=True), + Column('type', String(30)), + Column('data', String(20))) + + bar = Table('bar', metadata, + Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('data', String(20))) + + blub = Table('blub', metadata, + Column('id', Integer, primary_key=True), + Column('foo_id', Integer, ForeignKey('foo.id')), + Column('bar_id', Integer, ForeignKey('bar.id')), + Column('data', String(20))) + + def create_test(polymorphic): + def test_get(self): + class Foo(object): + pass + + class Bar(Foo): + pass + + class Blub(Bar): + pass + + if polymorphic: + mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo') + mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar') + mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub') + else: + mapper(Foo, foo) + mapper(Bar, bar, inherits=Foo) + mapper(Blub, blub, inherits=Bar) + + sess = create_session() + f = Foo() + b = Bar() + bl = Blub() + sess.save(f) + sess.save(b) + sess.save(bl) + sess.flush() + + if polymorphic: + def go(): + assert sess.query(Foo).get(f.id) == f + assert sess.query(Foo).get(b.id) == b + assert sess.query(Foo).get(bl.id) == bl + assert sess.query(Bar).get(b.id) == b + assert sess.query(Bar).get(bl.id) == bl + assert sess.query(Blub).get(bl.id) == bl + + self.assert_sql_count(testbase.db, go, 0) + else: + # this is testing the 'wrong' behavior of using get() + # polymorphically with mappers that are not configured to be + # polymorphic. the important part being that get() always + # returns an instance of the query's type. + def go(): + assert sess.query(Foo).get(f.id) == f + + bb = sess.query(Foo).get(b.id) + assert isinstance(b, Foo) and bb.id==b.id + + bll = sess.query(Foo).get(bl.id) + assert isinstance(bll, Foo) and bll.id==bl.id + + assert sess.query(Bar).get(b.id) == b + + bll = sess.query(Bar).get(bl.id) + assert isinstance(bll, Bar) and bll.id == bl.id + + assert sess.query(Blub).get(bl.id) == bl + + self.assert_sql_count(testbase.db, go, 3) + + return test_get + + test_get_polymorphic = create_test(True) + test_get_nonpolymorphic = create_test(False) + + +class ConstructionTest(ORMTest): def define_tables(self, metadata): global content_type, content, product content_type = Table('content_type', metadata, @@ -72,7 +156,8 @@ class AddPropTest(ORMTest): ) content = Table('content', metadata, Column('id', Integer, primary_key=True), - Column('content_type_id', Integer, ForeignKey('content_type.id')) + Column('content_type_id', Integer, ForeignKey('content_type.id')), + Column('type', String(30)) ) product = Table('product', metadata, Column('id', Integer, ForeignKey('content.id'), primary_key=True) @@ -86,11 +171,15 @@ class AddPropTest(ORMTest): content_types = mapper(ContentType, content_type) contents = mapper(Content, content, properties={ 'content_type':relation(content_types) - }) - #contents.add_property('content_type', relation(content_types)) #adding this makes the inheritance stop working - # shouldnt throw exception - products = mapper(Product, product, inherits=contents) - # TODO: assertion ?? + }, polymorphic_identity='contents') + + products = mapper(Product, product, inherits=contents, polymorphic_identity='products') + + try: + compile_mappers() + assert False + except exceptions.ArgumentError, e: + assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" def testbackref(self): """tests adding a property to the superclass mapper""" @@ -98,8 +187,8 @@ class AddPropTest(ORMTest): class Content(object): pass class Product(Content): pass - contents = mapper(Content, content) - products = mapper(Product, product, inherits=contents) + contents = mapper(Content, content, polymorphic_on=content.c.type, polymorphic_identity='content') + products = mapper(Product, product, inherits=contents, polymorphic_identity='product') content_types = mapper(ContentType, content_type, properties={ 'content':relation(contents, backref='contenttype') }) @@ -278,12 +367,7 @@ class DistinctPKTest(ORMTest): def test_implicit(self): person_mapper = mapper(Person, person_table) mapper(Employee, employee_table, inherits=person_mapper) - try: - print class_mapper(Employee).primary_key - assert list(class_mapper(Employee).primary_key) == [person_table.c.id, employee_table.c.id] - assert False - except RuntimeWarning, e: - assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name." + assert list(class_mapper(Employee).primary_key) == [person_table.c.id] def test_explicit_props(self): person_mapper = mapper(Person, person_table) diff --git a/test/orm/inheritance/concrete.py b/test/orm/inheritance/concrete.py index 167b25256d..d95a96da5f 100644 --- a/test/orm/inheritance/concrete.py +++ b/test/orm/inheritance/concrete.py @@ -54,6 +54,7 @@ class ConcreteTest1(ORMTest): session.flush() session.clear() + print set([repr(x) for x in session.query(Employee).select()]) assert set([repr(x) for x in session.query(Employee).select()]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) assert set([repr(x) for x in session.query(Manager).select()]) == set(["Manager Tom knows how to manage things"]) assert set([repr(x) for x in session.query(Engineer).select()]) == set(["Engineer Kurt knows how to hack"])