From: Mike Bayer Date: Fri, 31 Mar 2006 04:27:05 +0000 (+0000) Subject: starting to refactor mapper slightly, adding entity_name, version_id_col, allowing... X-Git-Tag: rel_0_1_6~48 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=90ffb177ed88cac43d4c3cbdc568d0d0a93fd579;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git starting to refactor mapper slightly, adding entity_name, version_id_col, allowing keywords in mapper.options() --- diff --git a/doc/build/content/adv_datamapping.myt b/doc/build/content/adv_datamapping.myt index 1f98176a81..7a8fefd0c6 100644 --- a/doc/build/content/adv_datamapping.myt +++ b/doc/build/content/adv_datamapping.myt @@ -314,7 +314,14 @@ WHERE rowcount.user_id = users.user_id ORDER BY users.oid, addresses.oid # set the referenced mapper 'photos' to defer its loading of the column 'imagedata' m = book_mapper.options(defer('photos.imagedata')) - +

Options can also take a limited set of keyword arguments which will be applied to a new mapper. For example, to create a mapper that refreshes all objects loaded each time:

+ <&|formatting.myt:code&> + m2 = mapper.options(always_refresh=True) + +

Or, a mapper with different ordering:

+ <&|formatting.myt:code&> + m2 = mapper.options(order_by=[newcol]) + @@ -557,7 +564,16 @@ WHERE rowcount.user_id = users.user_id ORDER BY users.oid, addresses.oid address = r[1] - +<&|doclib.myt:item, name="arguments", description="Mapper Arguments" &> +

Other arguments not covered above include:

+ + <&|doclib.myt:item, name="extending", description="Extending Mapper" &>

Mappers can have functionality augmented or replaced at many points in its execution via the usage of the MapperExtension class. This class is just a series of "hooks" where various functionality takes place. An application can make its own MapperExtension objects, overriding only the methods it needs. <&|formatting.myt:code&> diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index f8faea8555..7e12459c53 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -40,7 +40,9 @@ class Mapper(object): extension = None, order_by = False, allow_column_override = False, + entity_name = None, always_refresh = False, + version_id_col = None, **kwargs): if primarytable is not None: @@ -55,6 +57,8 @@ class Mapper(object): self.order_by = order_by self._options = {} self.always_refresh = always_refresh + self.entity_name = entity_name + self.version_id_col = version_id_col if not issubclass(class_, object): raise ArgumentError("Class '%s' is not a new-style class" % class_.__name__) @@ -85,7 +89,7 @@ class Mapper(object): # stricter set of tables to create "sync rules" by,based on the immediate # inherited table, rather than all inherited tables self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) - self._synchronizer.compile(self.table.onclause, util.HashSet([inherits.noninherited_table]), TableFinder(table)) + self._synchronizer.compile(self.table.onclause, util.HashSet([inherits.noninherited_table]), mapperutil.TableFinder(table)) # the old rule #self._synchronizer.compile(self.table.onclause, inherits.tables, TableFinder(table)) else: @@ -100,7 +104,7 @@ class Mapper(object): # locate all tables contained within the "table" passed in, which # may be a join or other construct - self.tables = TableFinder(self.table) + self.tables = mapperutil.TableFinder(self.table) # determine primary key columns, either passed in, or get them from our set of tables self.pks_by_table = {} @@ -350,9 +354,10 @@ class Mapper(object): compiling or executing it""" return self._compile(whereclause, **options) - def copy(self): + def copy(self, **kwargs): mapper = Mapper.__new__(Mapper) mapper.__dict__.update(self.__dict__) + mapper.__dict__.update(kwargs) mapper.props = self.props.copy() return mapper @@ -374,7 +379,7 @@ class Mapper(object): return callit return Proxy() - def options(self, *options): + def options(self, *options, **kwargs): """uses this mapper as a prototype for a new mapper with different behavior. *options is a list of options directives, which include eagerload(), lazyload(), and noload()""" @@ -382,7 +387,7 @@ class Mapper(object): try: return self._options[optkey] except KeyError: - mapper = self.copy() + mapper = self.copy(**kwargs) for option in options: option.process(mapper) self._options[optkey] = mapper @@ -610,7 +615,13 @@ class Mapper(object): self.extension.before_update(self, obj) hasdata = False for col in table.columns: - if self.pks_by_table[table].contains(col): + if col is self.version_id_col: + if not isinsert: + params[col._label] = self._getattrbycolumn(obj, col) + params[col.key] = params[col._label] + 1 + else: + params[col.key] = 1 + elif self.pks_by_table[table].contains(col): # column is a primary key ? if not isinsert: # doing an UPDATE? put primary key values as "WHERE" parameters @@ -664,6 +675,8 @@ class Mapper(object): clause = sql.and_() for col in self.pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col._label)) + if self.version_id_col is not None: + clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col._label)) statement = table.update(clause) rows = 0 for rec in update: @@ -729,11 +742,15 @@ class Mapper(object): delete.append(params) for col in self.pks_by_table[table]: params[col.key] = self._getattrbycolumn(obj, col) + if self.version_id_col is not None: + params[self.version_id_col.key] = self._getattrbycolumn(obj, self.version_id_col) self.extension.before_delete(self, obj) if len(delete): clause = sql.and_() for col in self.pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col.key)) + if self.version_id_col is not None: + clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col.key)) statement = table.delete(clause) c = statement.execute(*delete) if table.engine.supports_sane_rowcount() and c.rowcount != len(delete): @@ -1036,28 +1053,6 @@ class MapperExtension(object): if self.next is not None: self.next.before_delete(mapper, instance) -class TableFinder(sql.ClauseVisitor): - """given a Clause, locates all the Tables within it into a list.""" - def __init__(self, table, check_columns=False): - self.tables = [] - self.check_columns = check_columns - if table is not None: - table.accept_visitor(self) - def visit_table(self, table): - self.tables.append(table) - def __len__(self): - return len(self.tables) - def __getitem__(self, i): - return self.tables[i] - def __iter__(self): - return iter(self.tables) - def __contains__(self, obj): - return obj in self.tables - def __add__(self, obj): - return self.tables + list(obj) - def visit_column(self, column): - if self.check_columns: - column.table.accept_visitor(self) def hash_key(obj): if obj is None: