From: Mike Bayer Date: Fri, 30 Nov 2007 00:40:56 +0000 (+0000) Subject: - added support for version_id_col in conjunction with inheriting mappers. X-Git-Tag: rel_0_4_2~120 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=2b306350e24e4e7e304c85b9ea5e746b6b38adb2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - added support for version_id_col in conjunction with inheriting mappers. version_id_col is typically set on the base mapper in an inheritance relationship where it takes effect for all inheriting mappers. [ticket:883] - a little rearrangement of save_obj() --- diff --git a/CHANGES b/CHANGES index dded89564d..1dc797d77e 100644 --- a/CHANGES +++ b/CHANGES @@ -28,6 +28,11 @@ CHANGES i.e.: 'somename':synonym('_somename', map_column=True) will map the column named 'somename' to the attribute '_somename'. See the example in the mapper docs. [ticket:801] + + - added support for version_id_col in conjunction with inheriting mappers. + version_id_col is typically set on the base mapper in an inheritance + relationship where it takes effect for all inheriting mappers. + [ticket:883] - fixed endless loop issue when using lazy="dynamic" on both sides of a bi-directional relationship [ticket:872] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 570251c69c..52523d7876 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -338,6 +338,9 @@ class Mapper(object): self._identity_class = self.inherits._identity_class else: self._identity_class = self.class_ + + if self.version_id_col is None: + self.version_id_col = self.inherits.version_id_col if self.order_by is False: self.order_by = self.inherits.order_by @@ -405,12 +408,9 @@ class Mapper(object): # and assemble primary key columns for t in self.tables + [self.mapped_table]: self._all_tables.add(t) - try: - l = self.pks_by_table[t] - except KeyError: - l = self.pks_by_table.setdefault(t, util.OrderedSet()) - for k in t.primary_key: - l.add(k) + if t not in self.pks_by_table: + self.pks_by_table[t] = util.OrderedSet() + self.pks_by_table[t].update(t.primary_key) if self.primary_key_argument is not None: for k in self.primary_key_argument: @@ -974,8 +974,7 @@ class Mapper(object): # UPDATE if so. mapper = object_mapper(obj) instance_key = mapper.identity_key_from_instance(obj) - is_row_switch = not postupdate and not has_identity(obj) and instance_key in uowtransaction.uow.identity_map - if is_row_switch: + if not postupdate and not has_identity(obj) and instance_key in uowtransaction.uow.identity_map: existing = uowtransaction.uow.identity_map[instance_key] if not uowtransaction.is_deleted(existing): raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (mapperutil.instance_str(obj), str(instance_key), mapperutil.instance_str(existing))) @@ -1003,6 +1002,7 @@ class Mapper(object): mapper = object_mapper(obj) if table not in mapper.tables or not mapper._has_pks(table): continue + pks = mapper.pks_by_table[table] instance_key = mapper.identity_key_from_instance(obj) if self.__should_log_debug: self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.instance_str(obj), str(instance_key))) @@ -1011,46 +1011,46 @@ class Mapper(object): params = {} value_params = {} hasdata = False - for col in table.columns: - if col is mapper.version_id_col: - if not isinsert: - params[col._label] = mapper.get_attr_by_column(obj, col) - params[col.key] = params[col._label] + 1 - else: + + if isinsert: + for col in table.columns: + if col is mapper.version_id_col: params[col.key] = 1 - elif col in mapper.pks_by_table[table]: - # column is a primary key ? - if not isinsert: - # doing an UPDATE? put primary key values as "WHERE" parameters - # matching the bindparam we are creating below, i.e. "_" - params[col._label] = mapper.get_attr_by_column(obj, col) - else: - # doing an INSERT, primary key col ? - # if the primary key values are not populated, - # leave them out of the INSERT altogether, since PostGres doesn't want - # them to be present for SERIAL to take effect. A SQLEngine that uses - # explicit sequences will put them back in if they are needed + elif col in pks: value = mapper.get_attr_by_column(obj, col) if value is not None: params[col.key] = value - elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col): - if isinsert: + elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col): if self.__should_log_debug: self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key)) value = mapper.polymorphic_identity if col.default is None or value is not None: params[col.key] = value - else: - # column is not a primary key ? - if not isinsert: - # doing an UPDATE ? get the history for the attribute, with "passive" - # so as not to trigger any deferred loads. if there is a new - # value, add it to the bind parameters - if post_update_cols is not None and col not in post_update_cols: + else: + value = mapper.get_attr_by_column(obj, col, False) + if value is NO_ATTRIBUTE: continue - elif is_row_switch: - params[col.key] = self.get_attr_by_column(obj, col) - hasdata = True + if col.default is None or value is not None: + if isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value + insert.append((obj, params, mapper, connection, value_params)) + else: + for col in table.columns: + if col is mapper.version_id_col: + params[col._label] = mapper.get_attr_by_column(obj, col) + params[col.key] = params[col._label] + 1 + for prop in mapper._columntoproperty.values(): + history = prop.get_history(obj, passive=True) + if history and history.added_items(): + hasdata = True + elif col in pks: + params[col._label] = mapper.get_attr_by_column(obj, col) + elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col): + pass + else: + if post_update_cols is not None and col not in post_update_cols: continue prop = mapper._getpropbycolumn(col, False) if prop is None: @@ -1064,38 +1064,17 @@ class Mapper(object): else: params[col.key] = prop.get_col_value(col, a[0]) hasdata = True - else: - # doing an INSERT, non primary key col ? - # add the attribute's value to the - # bind parameters, unless its None and the column has a - # default. if its None and theres no default, we still might - # not want to put it in the col list but SQLIte doesnt seem to like that - # if theres no columns at all - value = mapper.get_attr_by_column(obj, col, False) - if value is NO_ATTRIBUTE: - continue - if col.default is None or value is not None: - # TODO: clauseelments as bind params should - # be handled by Insert/Update expression upon execute() - if isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value - - if not isinsert: if hasdata: # if none of the attributes changed, dont even # add the row to be updated. update.append((obj, params, mapper, connection, value_params)) - else: - insert.append((obj, params, mapper, connection, value_params)) if update: mapper = table_to_mapper[table] clause = sql.and_() for col in mapper.pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True)) - if mapper.version_id_col is not None: + if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True)) statement = table.update(clause) rows = 0 @@ -1230,7 +1209,7 @@ class Mapper(object): delete.setdefault(connection, []).append(params) for col in mapper.pks_by_table[table]: params[col.key] = mapper.get_attr_by_column(obj, col) - if mapper.version_id_col is not None: + if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col) # testlib.pragma exempt:__hash__ deleted_objects.add((id(obj), obj, connection)) @@ -1246,7 +1225,7 @@ class Mapper(object): clause = sql.and_() for col in mapper.pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True)) - if mapper.version_id_col is not None: + if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True)) statement = table.delete(clause) c = connection.execute(statement, del_objects) @@ -1259,6 +1238,7 @@ class Mapper(object): mapper.extension.after_delete(mapper, connection, obj) def _has_pks(self, table): + # TODO: determine this beforehand if self.pks_by_table.get(table, None): for k in self.pks_by_table[table]: if k not in self._columntoproperty: diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index 68d7956cc5..0301d01c99 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -3,7 +3,7 @@ from sqlalchemy import * from sqlalchemy import exceptions, util from sqlalchemy.orm import * from testlib import * - +from testlib import fixtures class O2MTest(ORMTest): """deals with inheritance and one-to-many relationships""" @@ -332,6 +332,109 @@ class FlushTest(ORMTest): sess.flush() assert user_roles.count().scalar() == 1 +class VersioningTest(ORMTest): + def define_tables(self, metadata): + global base, subtable, stuff + base = Table('base', metadata, + Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ), + Column('version_id', Integer, nullable=False), + Column('value', String(40)), + Column('discriminator', Integer, nullable=False) + ) + subtable = Table('subtable', metadata, + Column('id', None, ForeignKey('base.id'), primary_key=True), + Column('subdata', String(50)) + ) + stuff = Table('stuff', metadata, + Column('id', Integer, primary_key=True), + Column('parent', Integer, ForeignKey('base.id')) + ) + + @engines.close_open_connections + def test_save_update(self): + class Base(fixtures.Base): + pass + class Sub(Base): + pass + class Stuff(Base): + pass + mapper(Stuff, stuff) + mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1, properties={ + 'stuff':relation(Stuff) + }) + mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) + + sess = create_session() + + b1 = Base(value='b1') + s1 = Sub(value='sub1', subdata='some subdata') + sess.save(b1) + sess.save(s1) + + sess.flush() + + sess2 = create_session() + s2 = sess2.query(Base).get(s1.id) + s2.subdata = 'sess2 subdata' + + s1.subdata = 'sess1 subdata' + + sess.flush() + + try: + sess2.query(Base).with_lockmode('read').get(s1.id) + assert False + except exceptions.ConcurrentModificationError, e: + assert True + + try: + sess2.flush() + assert False + except exceptions.ConcurrentModificationError, e: + assert True + + sess2.refresh(s2) + assert s2.subdata == 'sess1 subdata' + s2.subdata = 'sess2 subdata' + sess2.flush() + + def test_delete(self): + class Base(fixtures.Base): + pass + class Sub(Base): + pass + + mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1) + mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) + + sess = create_session() + + b1 = Base(value='b1') + s1 = Sub(value='sub1', subdata='some subdata') + s2 = Sub(value='sub2', subdata='some other subdata') + sess.save(b1) + sess.save(s1) + sess.save(s2) + + sess.flush() + + sess2 = create_session() + s3 = sess2.query(Base).get(s1.id) + sess2.delete(s3) + sess2.flush() + + s2.subdata = 'some new subdata' + sess.flush() + + try: + s1.subdata = 'some new subdata' + sess.flush() + assert False + except exceptions.ConcurrentModificationError, e: + assert True + + + class DistinctPKTest(ORMTest): """test the construction of mapper.primary_key when an inheriting relationship joins on a column other than primary key column."""