From 6aae45aeaa008c034fc57d794a8f2a6ee2218dd3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 26 Feb 2006 22:23:01 +0000 Subject: [PATCH] implemented SyncRules for mapper with inheritance relationship, fixes [ticket:81] TableFinder becomes a list-implementing object (should probably create clauseutils or sqlutils for these little helper visitors) --- lib/sqlalchemy/mapping/mapper.py | 26 +++++++++++++++++++------- lib/sqlalchemy/mapping/properties.py | 4 +++- lib/sqlalchemy/mapping/sync.py | 5 ++--- test/inheritance.py | 6 ++++-- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 4541c899a9..a4de28a2da 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -9,6 +9,7 @@ import sqlalchemy.sql as sql import sqlalchemy.schema as schema import sqlalchemy.engine as engine import sqlalchemy.util as util +import sync from sqlalchemy.exceptions import * import objectstore import sys @@ -64,14 +65,15 @@ class Mapper(object): self.primarytable = inherits.primarytable # inherit_condition is optional since the join can figure it out self.table = sql.join(inherits.table, table, inherit_condition) + self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) + self._synchronizer.compile(self.table.onclause, inherits.tables, TableFinder(table)) else: self.primarytable = self.table - + self._synchronizer = None + # locate all tables contained within the "table" passed in, which # may be a join or other construct - tf = TableFinder() - self.table.accept_visitor(tf) - self.tables = tf.tables + self.tables = TableFinder(self.table) # determine primary key columns, either passed in, or get them from our set of tables self.pks_by_table = {} @@ -170,7 +172,6 @@ class Mapper(object): self.props[key] = prop.copy() self.props[key].parent = self self.props[key].key = None # force re-init - l = [(key, prop) for key, prop in self.props.iteritems()] for key, prop in l: if getattr(prop, 'key', None) is None: @@ -589,6 +590,8 @@ class Mapper(object): for c in table.c: if self._getattrbycolumn(obj, c) is None: self._setattrbycolumn(obj, c, row[c]) + if self._synchronizer is not None: + self._synchronizer.execute(obj, obj) self.extension.after_insert(self, obj) def delete_obj(self, objects, uow): @@ -878,11 +881,20 @@ class MapperExtension(object): class TableFinder(sql.ClauseVisitor): """given a Clause, locates all the Tables within it into a list.""" - def __init__(self): + def __init__(self, table): self.tables = [] + table.accept_visitor(self) def visit_table(self, table): self.tables.append(table) - + def __getitem__(self, i): + return self.tables[i] + def __iter__(self): + return iter(self.tables) + def __contains__(self, obj): + return obj in self.tables + def __add__(self, obj): + return self.tables + obj + def hash_key(obj): if obj is None: return 'None' diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index de782f17d1..85fd7418b2 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -49,7 +49,9 @@ class ColumnProperty(MapperProperty): def execute(self, instance, row, identitykey, imap, isnew): if isnew: instance.__dict__[self.key] = row[self.columns[0]] - + def __repr__(self): + return "ColumnProperty(%s)" % repr([str(c) for c in self.columns]) + class DeferredColumnProperty(ColumnProperty): """describes an object attribute that corresponds to a table column, which also will "lazy load" its value from the table. this is per-column lazy loading.""" diff --git a/lib/sqlalchemy/mapping/sync.py b/lib/sqlalchemy/mapping/sync.py index d5c1f87bf4..b322780118 100644 --- a/lib/sqlalchemy/mapping/sync.py +++ b/lib/sqlalchemy/mapping/sync.py @@ -1,7 +1,6 @@ import sqlalchemy.sql as sql import sqlalchemy.schema as schema from sqlalchemy.exceptions import * -import properties """contains the ClauseSynchronizer class which is used to map attributes between two objects in a manner corresponding to a SQL clause that compares column values.""" @@ -74,7 +73,7 @@ class ClauseSynchronizer(object): if len(self.syncrules) == rules_added: raise ArgumentError("No syncrules generated for join criterion " + str(sqlclause)) - def execute(self, source, dest, obj, child, clearkeys): + def execute(self, source, dest, obj=None, child=None, clearkeys=None): for rule in self.syncrules: rule.execute(source, dest, obj, child, clearkeys) @@ -110,7 +109,7 @@ class SyncRule(object): if isinstance(dest, dict): dest[self.dest_column.key] = value else: - #print "SYNC VALUE", value, "TO", dest + #print "SYNC VALUE", value, "TO", dest, self.source_column, self.dest_column self.dest_mapper._setattrbycolumn(dest, self.dest_column, value) class BinaryVisitor(sql.ClauseVisitor): diff --git a/test/inheritance.py b/test/inheritance.py index 7e6510bb29..82d66e48c7 100644 --- a/test/inheritance.py +++ b/test/inheritance.py @@ -127,8 +127,10 @@ class InheritTest2(testbase.AssertMixin): return "Bar(%s)" % self.data Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper, properties = { - # TODO: use syncrules for this - 'id':[bar.c.bid, foo.c.id] + # the old way, you needed to explicitly set up a compound + # column like this. but now the mapper uses SyncRules to match up + # the parent/child inherited columns + #'id':[bar.c.bid, foo.c.id] }) Bar.mapper.add_property('foos', relation(Foo.mapper, foo_bar, primaryjoin=bar.c.bid==foo_bar.c.bar_id, secondaryjoin=foo_bar.c.foo_id==foo.c.id, lazy=False)) -- 2.47.2