From: Mike Bayer Date: Mon, 14 Apr 2008 15:49:39 +0000 (+0000) Subject: - added experimental relation() flag to help with primaryjoins X-Git-Tag: rel_0_5beta1~185 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1d1484b210dba39dac47ab5a35f4336de6dbf9ec;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - added experimental relation() flag to help with primaryjoins across functions, etc., _local_remote_pairs=[tuples]. This complements a complex primaryjoin condition allowing you to provide the individual column pairs which comprise the relation's local and remote sides. --- diff --git a/CHANGES b/CHANGES index fc2c38b0ef..37861592d7 100644 --- a/CHANGES +++ b/CHANGES @@ -13,6 +13,12 @@ CHANGES - Also re-established viewonly relation() configurations that join across multiple tables. + - added experimental relation() flag to help with primaryjoins + across functions, etc., _local_remote_pairs=[tuples]. + This complements a complex primaryjoin condition allowing + you to provide the individual column pairs which comprise + the relation's local and remote sides. + - removed ancient assertion that mapped selectables require "alias names" - the mapper creates its own alias now if none is present. Though in this case you need to use diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 560c304da2..5508b5c8e3 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -185,11 +185,11 @@ class Mapper(object): def __log(self, msg): if self.__should_log_info: - self.logger.info("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self.non_primary and "|non-primary" or "") + ") " + msg) + self.logger.info("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") + ") " + msg) def __log_debug(self, msg): if self.__should_log_debug: - self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self.non_primary and "|non-primary" or "") + ") " + msg) + self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") + ") " + msg) def _is_orphan(self, obj): o = False diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index c4bb323de5..0eece6466f 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -210,7 +210,7 @@ class PropertyLoader(StrategizedProperty): of items that correspond to a related database table. """ - def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, passive_updates=True, remote_side=None, enable_typechecks=True, join_depth=None, strategy_class=None): + def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, passive_updates=True, remote_side=None, enable_typechecks=True, join_depth=None, strategy_class=None, _local_remote_pairs=None): self.uselist = uselist self.argument = argument self.entity_name = entity_name @@ -232,6 +232,7 @@ class PropertyLoader(StrategizedProperty): self.enable_typechecks = enable_typechecks self.comparator = PropertyLoader.Comparator(self) self.join_depth = join_depth + self._arg_local_remote_pairs = _local_remote_pairs if strategy_class: self.strategy_class = strategy_class @@ -538,26 +539,39 @@ class PropertyLoader(StrategizedProperty): self.secondary.c.contains_column(column) is not None def __determine_fks(self): + if self._legacy_foreignkey and not self._refers_to_parent_table(): self.foreign_keys = self._legacy_foreignkey arg_foreign_keys = self.foreign_keys + + if self._arg_local_remote_pairs: + if not arg_foreign_keys: + raise exceptions.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument") + self.foreign_keys = util.OrderedSet(arg_foreign_keys) + self._opposite_side = util.OrderedSet() + for l, r in self._arg_local_remote_pairs: + if r in self.foreign_keys: + self._opposite_side.add(l) + elif l in self.foreign_keys: + self._opposite_side.add(r) + self.synchronize_pairs = zip(self._opposite_side, self.foreign_keys) + else: + eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly) + eq_pairs = [(l, r) for l, r in eq_pairs if (self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)) or r in arg_foreign_keys] + + if not eq_pairs: + if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True): + raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. " + "For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.primaryjoin, self) + ) + else: + raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. " + "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self)) - eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly) - eq_pairs = [(l, r) for l, r in eq_pairs if (self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)) or r in arg_foreign_keys] - - if not eq_pairs: - if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True): - raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. " - "For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.primaryjoin, self) - ) - else: - raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. " - "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self)) - - self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs]) - self._opposite_side = util.OrderedSet([l for l, r in eq_pairs]) - self.synchronize_pairs = eq_pairs + self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs]) + self._opposite_side = util.OrderedSet([l for l, r in eq_pairs]) + self.synchronize_pairs = eq_pairs if self.secondaryjoin: sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly) @@ -579,7 +593,14 @@ class PropertyLoader(StrategizedProperty): self.secondary_synchronize_pairs = None def __determine_remote_side(self): - if self.remote_side: + if self._arg_local_remote_pairs: + if self.remote_side: + raise exceptions.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.") + if self.direction is MANYTOONE: + eq_pairs = [(r, l) for l, r in self._arg_local_remote_pairs] + else: + eq_pairs = self._arg_local_remote_pairs + elif self.remote_side: if self.direction is MANYTOONE: eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True) else: @@ -594,11 +615,11 @@ class PropertyLoader(StrategizedProperty): eq_pairs += sq_pairs eq_pairs = [(l, r) for l, r in eq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)] - self.local_remote_pairs = eq_pairs if self.direction is MANYTOONE: self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)] else: self.local_side, self.remote_side = [util.OrderedSet(s) for s in zip(*eq_pairs)] + self.local_remote_pairs = zip(self.local_side, self.remote_side) if self.direction is ONETOMANY: for l in self.local_side: diff --git a/test/orm/relationships.py b/test/orm/relationships.py index 22b69af2aa..6e9a0f96e1 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -1052,7 +1052,101 @@ class ViewOnlyTest6(ORMTest): }) mapper(T3, t3) self.assertRaisesMessage(exceptions.ArgumentError, "Specify remote_side argument", compile_mappers) + +class ExplicitLocalRemoteTest(ORMTest): + def define_tables(self, metadata): + global t1, t2 + t1 = Table('t1', metadata, + Column('id', String(50), primary_key=True), + Column('data', String(50)) + ) + t2 = Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('t1id', String(50)), + ) + + def test_onetomany(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + + # use a function within join condition. but specifying + # local_remote_pairs overrides all parsing of the join condition. + mapper(T1, t1, properties={ + 't2s':relation(T2, primaryjoin=t1.c.id==func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id] + ) + }) + mapper(T2, t2) + + sess = create_session() + a1 = T1(id='number1', data='a1') + b1 = T2(data='b1', t1id='NuMbEr1') + b2 = T2(data='b2', t1id='Number1') + sess.save(a1) + sess.save(b1) + sess.save(b2) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(T1).first(), T1(id='number1', data='a1', t2s=[T2(data='b1', t1id='NuMbEr1'), T2(data='b2', t1id='Number1')])) + + def test_manytoone(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + mapper(T1, t1) + mapper(T2, t2, properties={ + 't1':relation(T1, primaryjoin=t1.c.id==func.lower(t2.c.t1id), + _local_remote_pairs=[(t2.c.t1id, t1.c.id)], + foreign_keys=[t2.c.t1id] + ) + }) + sess = create_session() + a1 = T1(id='number1', data='a1') + b1 = T2(data='b1', t1id='NuMbEr1') + b2 = T2(data='b2', t1id='Number1') + sess.save(a1) + sess.save(b1) + sess.save(b2) + sess.flush() + sess.clear() + self.assertEquals(sess.query(T2).all(), + [ + T2(data='b1', t1=T1(id='number1', data='a1')), + T2(data='b2', t1=T1(id='number1', data='a1')) + ] + ) + def test_escalation(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + + mapper(T1, t1, properties={ + 't2s':relation(T2, primaryjoin=t1.c.id==func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id], + remote_side=[t2.c.t1id] + ) + }) + mapper(T2, t2) + self.assertRaises(exceptions.ArgumentError, compile_mappers) + + clear_mappers() + mapper(T1, t1, properties={ + 't2s':relation(T2, primaryjoin=t1.c.id==func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + ) + }) + mapper(T2, t2) + self.assertRaises(exceptions.ArgumentError, compile_mappers) + class InvalidRelationEscalationTest(ORMTest): def define_tables(self, metadata): global foos, bars, Foo, Bar