From: Mike Bayer Date: Thu, 15 Dec 2005 04:25:59 +0000 (+0000) Subject: rethinking sequences model to allow any default values X-Git-Tag: rel_0_1_0~239 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=893ce7768262debc7c347afdb37ca2a22d1ef9a7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git rethinking sequences model to allow any default values --- diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index cad60b9697..977b4c427e 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -241,6 +241,41 @@ class SQLEngine(schema.SchemaEngine): self.do_commit(self.context.transaction) self.context.transaction = None self.context.tcount = None + + + def _process_sequences(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs): + if compiled is None: return + if getattr(compiled, "isinsert", False): + if isinstance(parameters, list): + plist = parameters + else: + plist = [parameters] + # inserts are usually one at a time. but if we got a list of parameters, + # it will calculate last_inserted_ids for just the last row in the list. + # TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence + # it or post-select anyway + for param in plist: + last_inserted_ids = [] + need_lastrowid=False + for c in compiled.statement.table.c: + if not param.has_key(c.key) or param[c.key] is None: + if c.sequence is not None: + newid = self.exec_sequence(c.sequence) + else: + newid = None + + if newid is not None: + param[c.key] = newid + if c.primary_key: + last_inserted_ids.append(param[c.key]) + elif c.primary_key: + need_lastrowid = True + elif c.primary_key: + last_inserted_ids.append(param[c.key]) + if need_lastrowid: + self.context.last_inserted_ids = None + else: + self.context.last_inserted_ids = last_inserted_ids def pre_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs): pass @@ -287,7 +322,8 @@ class SQLEngine(schema.SchemaEngine): try: self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs) - + #self._process_sequences(connection, c, statement, parameters, echo = echo, **kwargs) + if echo is True or self.echo is not False: self.log(statement) self.log(repr(parameters)) diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 5fdf0e35b0..cfa373e0ae 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -459,6 +459,7 @@ class Mapper(object): if hasattr(obj, "_instance_key"): update.append(params) else: + self.extension.before_insert(self, obj) insert.append((obj, params)) uow.register_saved_object(obj) if len(update): @@ -476,6 +477,11 @@ class Mapper(object): for rec in insert: (obj, params) = rec statement.execute(**params) + # TODO: the engine is going to now do defaults for non-primarykey columns as well. + # have the engine store a dictionary of all column/generated values and set them + # all up. also might want to have the last_inserted_ids that does a select actually + # go and get everything that was generated by DB-level defaults, not just primary key + # columns. primary_key = table.engine.last_inserted_ids() if primary_key is not None: i = 0 @@ -676,8 +682,8 @@ class ExtensionOption(MapperOption): def __init__(self, ext): self.ext = ext def process(self, mapper): - ext.next = mapper.extension - mapper.extension = ext + self.ext.next = mapper.extension + mapper.extension = self.ext class MapperExtension(object): def __init__(self): @@ -692,6 +698,9 @@ class MapperExtension(object): return True else: return self.next.append_result(mapper, row, imap, result, instance, isnew, populate_existing) + def before_insert(self, mapper, instance): + if self.next is not None: + self.next.before_insert(mapper, instance) def after_insert(self, mapper, instance): if self.next is not None: self.next.after_insert(mapper, instance) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 789bb6ed8c..60c42c25ad 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -288,8 +288,9 @@ class ForeignKey(SchemaItem): class Sequence(SchemaItem): """represents a sequence, which applies to Oracle and Postgres databases.""" - def __init__(self, name, start = None, increment = None, optional=False): + def __init__(self, name, func = None, start = None, increment = None, optional=False): self.name = name + self.func = func self.start = start self.increment = increment self.optional=optional