From: Mike Bayer Date: Sat, 11 Feb 2006 23:29:02 +0000 (+0000) Subject: more hammering of defaults. ORM will properly execute defaults and post-fetch rows... X-Git-Tag: rel_0_1_0~26 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=adbd3ae29eefd2f3a34d8dd5166eceff91347557;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git more hammering of defaults. ORM will properly execute defaults and post-fetch rows that contain passive defaults --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 3b4ae64a70..3eaed0bc1d 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -522,11 +522,10 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): def get_column_default_string(self, column): if isinstance(column.default, schema.PassiveDefault): - if not isinstance(column.default.arg, str): - arg = str(column.default.arg.compile(self.engine)) + if isinstance(column.default.arg, str): + return repr(column.default.arg) else: - arg = column.default.arg - return arg + return str(column.default.arg.compile(self.engine)) else: return None diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index f6dd251cd6..f9c7aaf0b0 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -131,7 +131,7 @@ def reflecttable(engine, table, ischema_names, use_mysql=False): coltype = coltype(*args) colargs= [] if default is not None: - colargs.append(PassiveDefault(default)) + colargs.append(PassiveDefault(sql.text(default, escape=False))) table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True) diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index e5e4ab063b..16ceb38735 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -508,7 +508,8 @@ class Mapper(object): # matching the bindparam we are creating below, i.e. "_" params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col) else: - # doing an INSERT? if the primary key values are not populated, + # doing an INSERT, primary key col ? + # if the primary key values are not populated, # leave them out of the INSERT altogether, since PostGres doesn't want # them to be present for SERIAL to take effect. A SQLEngine that uses # explicit sequences will put them back in if they are needed @@ -529,9 +530,15 @@ class Mapper(object): params[col.key] = a[0] hasdata = True else: - # doing an INSERT ? add the attribute's value to the - # bind parameters - params[col.key] = self._getattrbycolumn(obj, col) + # doing an INSERT, non primary key col ? + # add the attribute's value to the + # bind parameters, unless its None and the column has a + # default. if its None and theres no default, we still might + # not want to put it in the col list but SQLIte doesnt seem to like that + # if theres no columns at all + value = self._getattrbycolumn(obj, col) + if col.default is None or value is not None: + params[col.key] = value if not isinsert: if hasdata: @@ -572,8 +579,8 @@ class Mapper(object): clause.clauses.append(p == self._getattrbycolumn(obj, p)) row = table.select(clause).execute().fetchone() for c in table.c: - if self._getattrbycolumn(obj, col) is None: - self._setattrbycolumn(obj, col, row[c]) + if self._getattrbycolumn(obj, c) is None: + self._setattrbycolumn(obj, c, row[c]) self.extension.after_insert(self, obj) def delete_obj(self, objects, uow): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index c8a05470c1..d4ae77a4e9 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -638,7 +638,7 @@ class TextClause(ClauseElement): being specified as a bind parameter via the bindparam() method, since it provides more information about what it is, including an optional type, as well as providing comparison operations.""" - def __init__(self, text = "", engine=None, bindparams=None, typemap=None): + def __init__(self, text = "", engine=None, bindparams=None, typemap=None, escape=True): self.parens = False self._engine = engine self.id = id(self) @@ -650,8 +650,11 @@ class TextClause(ClauseElement): def repl(m): self.bindparams[m.group(1)] = bindparam(m.group(1)) return self.engine.bindtemplate % m.group(1) - - self.text = re.compile(r':([\w_]+)', re.S).sub(repl, text) + + if escape: + self.text = re.compile(r':([\w_]+)', re.S).sub(repl, text) + else: + self.text = text if bindparams is not None: for b in bindparams: self.bindparams[b.key] = b diff --git a/test/engines.py b/test/engines.py index f7bde7118a..7cc07dddaf 100644 --- a/test/engines.py +++ b/test/engines.py @@ -16,12 +16,21 @@ class EngineTest(PersistTest): use_function_defaults = testbase.db.engine.__module__.endswith('postgres') or testbase.db.engine.__module__.endswith('oracle') + use_string_defaults = use_function_defaults or testbase.db.engine.__module__.endswith('sqlite') + if use_function_defaults: defval = func.current_date() deftype = Date else: defval = "3" deftype = Integer + + if use_string_defaults: + deftype2 = String + defval2 = "im a default" + else: + deftype2 = Integer + defval2 = "15" users = Table('engine_users', testbase.db, Column('user_id', INT, primary_key = True), @@ -36,6 +45,8 @@ class EngineTest(PersistTest): Column('test7', String), Column('test8', Binary), Column('test_passivedefault', deftype, PassiveDefault(defval)), + Column('test_passivedefault2', Integer, PassiveDefault("5")), + Column('test_passivedefault3', deftype2, PassiveDefault(defval2)), Column('test9', Binary(100)), mysql_engine='InnoDB' ) diff --git a/test/objectstore.py b/test/objectstore.py index 223c4aba20..6d5a6ea09f 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -103,7 +103,57 @@ class PKTest(AssertMixin): objectstore.clear() e2 = Entry.mapper.get(e.multi_id, 2) self.assert_(e is not e2 and e._instance_key == e2._instance_key) + +class DefaultTest(AssertMixin): + def setUpAll(self): + #db.echo = 'debug' + use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite') + + if use_string_defaults: + hohotype = String + self.hohoval = "im hoho" + self.althohoval = "im different hoho" + else: + hohotype = Integer + self.hohoval = 9 + self.althohoval = 15 + self.table = Table('default_test', db, + Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True), + Column('hoho', hohotype, PassiveDefault(str(self.hohoval))), + Column('counter', Integer, PassiveDefault("7")), + Column('foober', String, default="im foober") + ) + self.table.create() + def tearDownAll(self): + self.table.drop() + def testbasic(self): + class Hoho(object):pass + assign_mapper(Hoho, self.table) + h1 = Hoho(hoho=self.althohoval) + h2 = Hoho(counter=12) + h3 = Hoho(hoho=self.althohoval, counter=12) + h4 = Hoho() + h5 = Hoho(foober='im the new foober') + objectstore.commit() + self.assert_(h1.hoho==self.althohoval) + self.assert_(h3.hoho==self.althohoval) + self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval) + self.assert_(h3.counter == h2.counter == 12) + self.assert_(h1.counter == h4.counter==h5.counter==7) + self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') + self.assert_(h5.foober=='im the new foober') + objectstore.clear() + l = Hoho.mapper.select() + (h1, h2, h3, h4, h5) = l + self.assert_(h1.hoho==self.althohoval) + self.assert_(h3.hoho==self.althohoval) + self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval) + self.assert_(h3.counter == h2.counter == 12) + self.assert_(h1.counter == h4.counter==h5.counter==7) + self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') + self.assert_(h5.foober=='im the new foober') + class SaveTest(AssertMixin): def setUpAll(self): diff --git a/test/query.py b/test/query.py index 6c4e017cd0..5841299c26 100644 --- a/test/query.py +++ b/test/query.py @@ -54,7 +54,7 @@ class QueryTest(PersistTest): f = select([func.count(1)], engine=db).scalar() if use_function_defaults: def1 = func.current_date() - def2 = "current_date" + def2 = text("current_date") deftype = Date ts = select([func.current_date()], engine=db).scalar() else: @@ -72,6 +72,7 @@ class QueryTest(PersistTest): t.create() try: t.insert().execute() + self.assert_(t.engine.lastrow_has_defaults()) t.insert().execute() t.insert().execute()