From: Mike Bayer Date: Tue, 6 Sep 2005 01:10:23 +0000 (+0000) Subject: (no commit message) X-Git-Tag: rel_0_1_0~793 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8b78a8ca306a79a70717aae7e6d830d2f9ddc9d4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git --- diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 74b72dd1bf..969a93a238 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -35,8 +35,8 @@ def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin else: return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **options) -def relation_mapper(class_, selectable, secondary = None, primaryjoin = None, secondaryjoin = None, table = None, properties = None, lazy = True, **options): - return relation_loader(mapper(class_, selectable, table = table, properties = properties, isroot = False, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, **options) +def relation_mapper(class_, selectable, secondary = None, primaryjoin = None, secondaryjoin = None, table = None, properties = None, lazy = True, uselist = True, **options): + return relation_loader(mapper(class_, selectable, table = table, properties = properties, isroot = False, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, uselist = uselist, **options) _mappers = {} def mapper(*args, **params): @@ -54,9 +54,8 @@ def lazyload(name): return EagerLazySwitcher(name, toeager = False) class Mapper(object): - def __init__(self, class_, selectable, table = None, scope = "thread", properties = None, use_smart_properties = True, isroot = True, echo = None): + def __init__(self, class_, selectable, table = None, scope = "thread", properties = None, isroot = True, echo = None): self.class_ = class_ - self.use_smart_properties = use_smart_properties self.scope = scope self.selectable = selectable tf = TableFinder() @@ -124,7 +123,6 @@ class Mapper(object): self.table, self.properties, self.scope, - self.use_smart_properties, self.echo ) @@ -385,32 +383,26 @@ class ColumnProperty(MapperProperty): def init(self, key, parent, root): self.key = key - if root.use_smart_properties: - self.use_smart = True - if not hasattr(parent.class_, key): - setattr(parent.class_, key, SmartProperty(key).property()) - else: - self.use_smart = False + if not hasattr(parent.class_, key): + setattr(parent.class_, key, SmartProperty(key).property()) def execute(self, instance, row, identitykey, localmap, isduplicate): if not isduplicate: - if self.use_smart: - clean_setattr(instance, self.key, row[self.columns[0].label]) - else: - setattr(instance, self.key, row[self.columns[0].label]) + clean_setattr(instance, self.key, row[self.columns[0].label]) class PropertyLoader(MapperProperty): """describes an object property that holds a list of items that correspond to a related database table.""" - def __init__(self, mapper, secondary, primaryjoin, secondaryjoin): + def __init__(self, mapper, secondary, primaryjoin, secondaryjoin, uselist = True): + self.uselist = uselist self.mapper = mapper self.target = self.mapper.selectable self.secondary = secondary self.primaryjoin = primaryjoin self.secondaryjoin = secondaryjoin - self._hash_key = "%s(%s, %s, %s, %s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin)) + self._hash_key = "%s(%s, %s, %s, %s, uselist=%s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), repr(self.uselist)) def hash_key(self): return self._hash_key @@ -427,6 +419,9 @@ class PropertyLoader(MapperProperty): else: if self.primaryjoin is None: self.primaryjoin = match_primaries(parent.selectable, self.target) + + if not self.uselist and not hasattr(parent.class_, key): + setattr(parent.class_, key, SmartProperty(key).property(usehistory = True)) def save(self, obj, traverse): # saves child objects @@ -436,11 +431,19 @@ class PropertyLoader(MapperProperty): secondary_insert = [] setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary, obj) - childlist = getattr(obj, self.key) - if not isinstance(childlist, util.HistoryArraySet): - childlist = util.HistoryArraySet(childlist) - clean_setattr(obj, self.key, childlist) - + + + if self.uselist: + childlist = getattr(obj, self.key) + if not isinstance(childlist, util.HistoryArraySet): + childlist = util.HistoryArraySet(childlist) + clean_setattr(obj, self.key, childlist) + else: + childlist = GetPropHistory() + # this is a nasty trick to communicate with a property() + setattr(obj, self.key, childlist) + childlist = childlist.history + for child in childlist.deleted_items(): setter.child = child setter.associationrow = {} @@ -452,6 +455,7 @@ class PropertyLoader(MapperProperty): secondary_delete.append(setter.associationrow) for child in childlist.added_items(): + print "yup " + repr(child) setter.child = child setter.associationrow = {} self.primaryjoin.accept_visitor(setter) @@ -515,8 +519,16 @@ class LazyLoadInstance(object): # quickly, so an object with a lazyloader still cant really be serialized self.mapper = lazyloader.mapper self.lazywhere = lazyloader.lazywhere + self.uselist = lazyloader.uselist def __call__(self): - return self.mapper.select(self.lazywhere, **self.params) + result = self.mapper.select(self.lazywhere, **self.params) + if self.uselist: + return result + else: + if len(result): + return result[0] + else: + return None class EagerLoader(PropertyLoader): """loads related objects inline with a parent query.""" @@ -561,14 +573,19 @@ class EagerLoader(PropertyLoader): def execute(self, instance, row, identitykey, localmap, isduplicate): """receive a row. tell our mapper to look for a new object instance in the row, and attach it to a list on the parent instance.""" - if not isduplicate: + if not self.uselist: + result_list = [] + elif not isduplicate: result_list = util.HistoryArraySet() clean_setattr(instance, self.key, result_list) else: result_list = getattr(instance, self.key) self.mapper._instance(row, localmap, result_list) - + + if not self.uselist: + clean_setattr(instance, self.key, result_list[0]) + class LazyRow(MapperProperty): """TODO: this will lazy-load additional properties of an object from a secondary table.""" def __init__(self, table, whereclause, **options): @@ -679,11 +696,29 @@ class SmartProperty(object): def __init__(self, key): self.key = key - def property(self): + def get_history(self, obj): + if not hasattr(obj, '_history'): + obj._history = {} + if not obj._history.has_key(self.key): + obj._history[self.key] = util.PropHistory(obj.__dict__.get(self.key, None)) + return obj._history[self.key] + + def property(self, usehistory = False): + # TODO: all the history/dirty crap here is temporary, should communicate with a + # thread-local unit of work def set_prop(s, value): + if usehistory: + hist = self.get_history(s) + if isinstance(value, GetPropHistory): + value.history = hist + return + hist.setattr(value, s.__dict__.get(self.key, None)) s.__dict__[self.key] = value s.dirty = True def del_prop(s): + if usehistory: + hist = self.get_history(s) + hist.delattr(value) del s.__dict__[self.key] s.dirty = True def get_prop(s): @@ -696,6 +731,8 @@ class SmartProperty(object): return s.__dict__[self.key] return property(get_prop, set_prop, del_prop) +class GetPropHistory:pass + identity_map = util.ScopedRegistry(lambda: {}) def clean_setattr(object, key, value): @@ -714,17 +751,16 @@ def hash_key(obj): else: return obj.hash_key() -def mapper_hash_key(class_, selectable, table = None, properties = None, scope = "thread", use_smart_properties = True, isroot = True, echo = None): +def mapper_hash_key(class_, selectable, table = None, properties = None, scope = "thread", isroot = True, echo = None): if properties is None: properties = {} return ( - "Mapper(%s, %s, table=%s, properties=%s, scope=%s, use_smart_properties=%s, echo=%s)" % ( + "Mapper(%s, %s, table=%s, properties=%s, scope=%s, echo=%s)" % ( repr(class_), hash_key(selectable), hash_key(table), repr(dict([(k, hash_key(p)) for k,p in properties.iteritems()])), scope, - repr(use_smart_properties), repr(echo) ) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index e6802bbbb0..7013e6e6cf 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -203,6 +203,36 @@ class HistoryArraySet(UserList.UserList): raise NotImplementedError() def __iadd__(self, other): raise NotImplementedError() + +class PropHistory(object): + def __init__(self, current): + self.added = None + self.current = current + self.deleted = None + def setattr(self, value, current): + self.current = None + self.deleted = current + self.added = value + def delattr(self, current): + self.deleted = current + def clear_history(self): + if self.added is not None: + self.current = self.added + def added_items(self): + if self.added is not None: + return [self.added] + else: + return [] + def deleted_items(self): + if self.deleted is not None: + return [self.deleted] + else: + return [] + def unchanged_items(self): + if self.current is not None: + return [self.current] + else: + return [] class ScopedRegistry(object): def __init__(self, createfunc): diff --git a/test/mapper.py b/test/mapper.py index ba1a20e90a..673884573b 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -23,7 +23,7 @@ Closed Orderss %s class Address(object): def __repr__(self): - return "Address: " + repr(self.address_id) + " " + repr(self.user_id) + " " + repr(self.email_address) + return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'user_id', None)) + " " + repr(self.email_address) class Order(object): def __repr__(self): @@ -310,6 +310,16 @@ class SaveTest(AssertMixin): u = m.select(users.c.user_id==u.foo_id)[0] print repr(u.__dict__) + def testonetoone(self): + m = mapper(User, users, properties = dict( + address = relation(Address, addresses, lazy = True, uselist = False) + )) + u = User() + u.user_name = 'one2onetester' + u.address = Address() + u.address.email_address = 'myonlyaddress@foo.com' + m.save(u) + def testonetomany(self): """test basic save of one to many.""" m = mapper(User, users, properties = dict(