From 40964f68a143ab211bfd903dcc6733bf1c77906a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 17 Dec 2005 02:49:47 +0000 Subject: [PATCH] refactoring of execution path, defaults, and treatment of different paramstyles --- lib/sqlalchemy/ansisql.py | 111 ++++++++++++++--------- lib/sqlalchemy/databases/mysql.py | 3 +- lib/sqlalchemy/databases/oracle.py | 3 +- lib/sqlalchemy/databases/postgres.py | 20 ++--- lib/sqlalchemy/databases/sqlite.py | 3 +- lib/sqlalchemy/engine.py | 129 +++++++++++++++++++-------- lib/sqlalchemy/schema.py | 1 + lib/sqlalchemy/sql.py | 36 ++++++-- test/query.py | 17 ++++ test/select.py | 4 +- test/tables.py | 2 +- test/testbase.py | 41 ++------- 12 files changed, 227 insertions(+), 143 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index cd1d3a0b0a..e4bcdd0775 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -37,8 +37,8 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine): def schemadropper(self, proxy, **params): return ANSISchemaDropper(proxy, **params) - def compiler(self, statement, bindparams, **kwargs): - return ANSICompiler(self, statement, bindparams, **kwargs) + def compiler(self, statement, parameters, **kwargs): + return ANSICompiler(self, statement, parameters, **kwargs) def connect_args(self): return ([],{}) @@ -47,8 +47,20 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine): return None class ANSICompiler(sql.Compiled): - def __init__(self, engine, statement, bindparams, typemap=None, paramstyle=None,**kwargs): - sql.Compiled.__init__(self, engine, statement, bindparams) + """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings.""" + def __init__(self, engine, statement, parameters=None, typemap=None, **kwargs): + """constructs a new ANSICompiler object. + + engine - SQLEngine to compile against + + statement - ClauseElement to be compiled + + parameters - optional dictionary indicating a set of bind parameters + specified with this Compiled object. These parameters are the "default" + key/value pairs when the Compiled is executed, and also may affect the + actual compilation, as in the case of an INSERT where the actual columns + inserted will correspond to the keys present in the parameters.""" + sql.Compiled.__init__(self, engine, statement, parameters) self.binds = {} self.froms = {} self.wheres = {} @@ -57,37 +69,18 @@ class ANSICompiler(sql.Compiled): self.typemap = typemap or {} self.isinsert = False - if paramstyle is None: - db = self.engine.dbapi() - if db is not None: - paramstyle = db.paramstyle - else: - paramstyle = 'named' - - if paramstyle == 'named': - self.bindtemplate = ':%s' - self.positional=False - elif paramstyle =='pyformat': - self.bindtemplate = "%%(%s)s" - self.positional=False - else: - # for positional, use pyformat until the end - self.bindtemplate = "%%(%s)s" - self.positional=True - self.paramstyle=paramstyle - def after_compile(self): - if self.positional: + if self.engine.positional: self.positiontup = [] match = r'%\(([\w_]+)\)s' params = re.finditer(match, self.strings[self.statement]) for p in params: self.positiontup.append(p.group(1)) - if self.paramstyle=='qmark': + if self.engine.paramstyle=='qmark': self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement]) - elif self.paramstyle=='format': + elif self.engine.paramstyle=='format': self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement]) - elif self.paramstyle=='numeric': + elif self.engine.paramstyle=='numeric': i = 0 def getnum(x): i += 1 @@ -116,14 +109,22 @@ class ANSICompiler(sql.Compiled): for an executemany style of call, this method should be called for each element in the list of parameter groups that will ultimately be executed. """ - d = {} - if self.bindparams is not None: - bindparams = self.bindparams.copy() + if self.parameters is not None: + bindparams = self.parameters.copy() else: bindparams = {} bindparams.update(params) - # TODO: cant we make "d" an ordereddict and add params in - # positional order + + if self.engine.positional: + d = OrderedDict() + for k in self.positiontup: + b = self.binds[k] + d[k] = b.typeprocess(b.value) + else: + d = {} + for b in self.binds.values(): + d[b.key] = b.typeprocess(b.value) + for key, value in bindparams.iteritems(): try: b = self.binds[key] @@ -131,11 +132,9 @@ class ANSICompiler(sql.Compiled): continue d[b.key] = b.typeprocess(value) - for b in self.binds.values(): - d.setdefault(b.key, b.typeprocess(b.value)) - - if self.positional: - return [d[key] for key in self.positiontup] + return d + if self.engine.positional: + return d.values() else: return d @@ -145,7 +144,8 @@ class ANSICompiler(sql.Compiled): same dictionary. For a positional paramstyle, the given parameters are assumed to be in list format and are converted back to a dictionary. """ - if self.positional: +# return parameters + if self.engine.positional: p = {} for i in range(0, len(self.positiontup)): p[self.positiontup[i]] = parameters[i] @@ -237,7 +237,7 @@ class ANSICompiler(sql.Compiled): self.strings[bindparam] = self.bindparam_string(key) def bindparam_string(self, name): - return self.bindtemplate % name + return self.engine.bindtemplate % name def visit_alias(self, alias): self.froms[alias] = self.get_from_text(alias.selectable) + " AS " + alias.name @@ -265,7 +265,7 @@ class ANSICompiler(sql.Compiled): text = "SELECT " if select.distinct: text += "DISTINCT " - text += collist + " \nFROM " + text += collist whereclause = select.whereclause @@ -282,8 +282,10 @@ class ANSICompiler(sql.Compiled): t = self.get_from_text(f) if t is not None: froms.append(t) - - text += string.join(froms, ', ') + + if len(froms): + text += " \nFROM " + text += string.join(froms, ', ') if whereclause is not None: t = self.get_str(whereclause) @@ -333,10 +335,31 @@ class ANSICompiler(sql.Compiled): self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext + " ON " + self.get_str(join.onclause)) self.strings[join] = self.froms[join] + + def visit_insert_column_default(self, column, default): + """called when visiting an Insert statement, for each column in the table that + contains a ColumnDefault object.""" + self.parameters.setdefault(column.key, None) + + def visit_insert_sequence(self, column, sequence): + """called when visiting an Insert statement, for each column in the table that + contains a Sequence object.""" + pass def visit_insert(self, insert_stmt): + # set up a call for the defaults and sequences inside the table + class DefaultVisitor(schema.SchemaVisitor): + def visit_column_default(s, cd): + self.visit_insert_column_default(c, cd) + def visit_sequence(s, seq): + self.visit_insert_sequence(c, seq) + vis = DefaultVisitor() + for c in insert_stmt.table.c: + if self.parameters.get(c.key, None) is None and c.default is not None: + c.default.accept_visitor(vis) + self.isinsert = True - colparams = insert_stmt.get_colparams(self.bindparams) + colparams = insert_stmt.get_colparams(self.parameters) for c in colparams: b = c[1] self.binds[b.key] = b @@ -348,7 +371,7 @@ class ANSICompiler(sql.Compiled): self.strings[insert_stmt] = text def visit_update(self, update_stmt): - colparams = update_stmt.get_colparams(self.bindparams) + colparams = update_stmt.get_colparams(self.parameters) def create_param(p): if isinstance(p, sql.BindParamClause): self.binds[p.key] = p diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 191a57ba60..96bacf2101 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -140,8 +140,7 @@ class MySQLEngine(ansisql.ANSISQLEngine): def last_inserted_ids(self): return self.context.last_inserted_ids - def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs): - if compiled is None: return + def post_exec(self, proxy, compiled, parameters, **kwargs): if getattr(compiled, "isinsert", False): self.context.last_inserted_ids = [proxy().lastrowid] diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index bc30a4937d..163d387bc4 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -118,8 +118,7 @@ class OracleSQLEngine(ansisql.ANSISQLEngine): def last_inserted_ids(self): return self.context.last_inserted_ids - def pre_exec(self, proxy, statement, parameters, compiled=None, **kwargs): - if compiled is None: return + def pre_exec(self, proxy, compiled, parameters, **kwargs): # this is just an assertion that all the primary key columns in an insert statement # have a value set up, or have a default generator ready to go if getattr(compiled, "isinsert", False): diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 29590da0ac..0ec84dec4d 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -153,7 +153,7 @@ class PGSQLEngine(ansisql.ANSISQLEngine): return PGSchemaDropper(proxy, **params) def defaultrunner(self, proxy): - return PGDefaultRunner(proxy) + return PGDefaultRunner(self, proxy) def get_default_schema_name(self): if not hasattr(self, '_default_schema_name'): @@ -166,8 +166,7 @@ class PGSQLEngine(ansisql.ANSISQLEngine): def pre_exec(self, proxy, statement, parameters, **kwargs): return - def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs): - if compiled is None: return + def post_exec(self, proxy, compiled, parameters, **kwargs): if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None: table = compiled.statement.table cursor = proxy() @@ -200,15 +199,10 @@ class PGSQLEngine(ansisql.ANSISQLEngine): ischema.reflecttable(self, table, ischema_names) class PGCompiler(ansisql.ANSICompiler): - def visit_insert(self, insert): - """inserts are required to have the primary keys be explicitly present. - mapper will by default not put them in the insert statement to comply - with autoincrement fields that require they not be present. so, - put them all in for columns where sequence usage is defined.""" - for c in insert.table.primary_key: - if self.bindparams.get(c.key, None) is None and c.default is not None and not c.default.optional: - self.bindparams[c.key] = None - return ansisql.ANSICompiler.visit_insert(self, insert) + + def visit_insert_sequence(self, column, sequence): + if self.parameters.get(column.key, None) is None and not sequence.optional: + self.parameters[column.key] = None def limit_clause(self, select): text = "" @@ -223,7 +217,7 @@ class PGCompiler(ansisql.ANSICompiler): class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name - if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or column.default.optional): + if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" else: colspec += " " + column.type.get_col_spec() diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index a70f65d2a2..e743d14e0b 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -108,8 +108,7 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): params['poolclass'] = sqlalchemy.pool.SingletonThreadPool ansisql.ANSISQLEngine.__init__(self, **params) - def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs): - if compiled is None: return + def post_exec(self, proxy, compiled, parameters, **kwargs): if getattr(compiled, "isinsert", False): self.context.last_inserted_ids = [proxy().lastrowid] diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 22ccbd11c0..81d72b17b2 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -78,17 +78,21 @@ class SchemaIterator(schema.SchemaVisitor): self.buffer.truncate(0) class DefaultRunner(schema.SchemaVisitor): - def __init__(self, proxy): + def __init__(self, engine, proxy): self.proxy = proxy + self.engine = engine def visit_sequence(self, seq): """sequences are not supported by default""" return None + def exec_default_sql(self, default): + c = sql.select([default.arg], engine=self.engine).compile() + return self.proxy(str(c), c.get_params()).fetchone()[0] + def visit_column_default(self, default): - if isinstance(default.arg, ClauseElement): - c = default.arg.compile() - return proxy.execute(str(c), c.get_params()) + if isinstance(default.arg, sql.ClauseElement): + return self.exec_default_sql(default) elif callable(default.arg): return default.arg() else: @@ -113,11 +117,29 @@ class SQLEngine(schema.SchemaEngine): self.context = util.ThreadLocal(raiseerror=False) self.tables = {} self.notes = {} + self._figure_paramstyle() if logger is None: self.logger = sys.stdout else: self.logger = logger - + + def _figure_paramstyle(self): + db = self.dbapi() + if db is not None: + self.paramstyle = db.paramstyle + else: + self.paramstyle = 'named' + + if self.paramstyle == 'named': + self.bindtemplate = ':%s' + self.positional=False + elif self.paramstyle =='pyformat': + self.bindtemplate = "%%(%s)s" + self.positional=False + else: + # for positional, use pyformat until the end + self.bindtemplate = "%%(%s)s" + self.positional=True def type_descriptor(self, typeobj): if type(typeobj) is type: @@ -131,9 +153,9 @@ class SQLEngine(schema.SchemaEngine): raise NotImplementedError() def defaultrunner(self, proxy): - return DefaultRunner(proxy) + return DefaultRunner(self, proxy) - def compiler(self, statement, bindparams): + def compiler(self, statement, parameters): raise NotImplementedError() def rowid_column_name(self): @@ -152,11 +174,11 @@ class SQLEngine(schema.SchemaEngine): """drops a table given a schema.Table object.""" table.accept_visitor(self.schemadropper(self.proxy(), **params)) - def compile(self, statement, bindparams, **kwargs): + def compile(self, statement, parameters, **kwargs): """given a sql.ClauseElement statement plus optional bind parameters, creates a new instance of this engine's SQLCompiler, compiles the ClauseElement, and returns the newly compiled object.""" - compiler = self.compiler(statement, bindparams, **kwargs) + compiler = self.compiler(statement, parameters, **kwargs) statement.accept_visitor(compiler) compiler.after_compile() return compiler @@ -263,26 +285,15 @@ class SQLEngine(schema.SchemaEngine): self.context.transaction = None self.context.tcount = None - def _process_defaults(self, proxy, statement, parameters, compiled=None, **kwargs): + def _process_defaults(self, proxy, compiled, parameters, **kwargs): if compiled is None: return if getattr(compiled, "isinsert", False): - # TODO: this sucks. we have to get "parameters" to be a more standardized object - if isinstance(parameters, list) and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)): + if isinstance(parameters, list): plist = parameters else: plist = [parameters] - # inserts are usually one at a time. but if we got a list of parameters, - # it will calculate last_inserted_ids for just the last row in the list. - # TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence - # it or post-select anyway drunner = self.defaultrunner(proxy) for param in plist: - # the parameters might be positional, so convert them - # to a dictionary - # TODO: this is stupid. or, is this stupid ? - # any way we can just have an OrderedDict so we have the - # dictionary + postional version each time ? - param = compiled.get_named_params(param) last_inserted_ids = [] need_lastrowid=False for c in compiled.statement.table.c: @@ -306,18 +317,18 @@ class SQLEngine(schema.SchemaEngine): self.context.last_inserted_ids = last_inserted_ids - def pre_exec(self, proxy, statement, parameters, **kwargs): + def pre_exec(self, proxy, compiled, parameters, **kwargs): pass - def post_exec(self, proxy, statement, parameters, **kwargs): + def post_exec(self, proxy, compiled, parameters, **kwargs): pass - def execute(self, statement, parameters, connection=None, cursor=None, echo = None, typemap = None, commit=False, **kwargs): + def execute_compiled(self, compiled, parameters, connection=None, cursor=None, echo=None, **kwargs): """executes the given string-based SQL statement with the given parameters. The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending on the paramstyle of the DBAPI. - + If the current thread has specified a transaction begin() for this engine, the statement will be executed in the context of the current transactional connection. Otherwise, a commit() will be performed immediately after execution, since the local @@ -352,6 +363,62 @@ class SQLEngine(schema.SchemaEngine): def proxy(statement=None, parameters=None): if statement is None: return cursor + + executemany = parameters is not None and isinstance(parameters, list) + + if self.positional: + if executemany: + parameters = [p.values() for p in parameters] + else: + parameters = parameters.values() + + self.execute(statement, parameters, connection=connection, cursor=cursor) + return cursor + + self.pre_exec(proxy, compiled, parameters, **kwargs) + self._process_defaults(proxy, compiled, parameters, **kwargs) + proxy(str(compiled), parameters) + self.post_exec(proxy, compiled, parameters, **kwargs) + return ResultProxy(cursor, self, typemap=compiled.typemap) + + def execute(self, statement, parameters, connection=None, cursor=None, echo = None, typemap = None, commit=False, **kwargs): + """executes the given string-based SQL statement with the given parameters. + + The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending + on the paramstyle of the DBAPI. + + If the current thread has specified a transaction begin() for this engine, the + statement will be executed in the context of the current transactional connection. + Otherwise, a commit() will be performed immediately after execution, since the local + pooled connection is returned to the pool after execution without a transaction set + up. + + In all error cases, a rollback() is immediately performed on the connection before + propigating the exception outwards. + + Other options include: + + connection - a DBAPI connection to use for the execute. If None, a connection is + pulled from this engine's connection pool. + + echo - enables echo for this execution, which causes all SQL and parameters + to be dumped to the engine's logging output before execution. + + typemap - a map of column names mapped to sqlalchemy.types.TypeEngine objects. + These will be passed to the created ResultProxy to perform + post-processing on result-set values. + + commit - if True, will automatically commit the statement after completion. """ + if parameters is None: + parameters = {} + + if connection is None: + connection = self.connection() + + if cursor is None: + cursor = connection.cursor() + + try: if echo is True or self.echo is not False: self.log(statement) self.log(repr(parameters)) @@ -359,18 +426,10 @@ class SQLEngine(schema.SchemaEngine): self._executemany(cursor, statement, parameters) else: self._execute(cursor, statement, parameters) - return cursor - - try: - self.pre_exec(proxy, statement, parameters, **kwargs) - self._process_defaults(proxy, statement, parameters, **kwargs) - proxy(statement, parameters) - self.post_exec(proxy, statement, parameters, **kwargs) - if commit or self.context.transaction is None: + if self.context.transaction is None: self.do_commit(connection) except: self.do_rollback(connection) - # TODO: wrap DB exceptions ? raise return ResultProxy(cursor, self, typemap = typemap) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 606bcf508d..13ed33e828 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -187,6 +187,7 @@ class Column(SchemaItem): self._impl = self.table.engine.columnimpl(self) if self.default is not None: + self.default = ColumnDefault(self.default) self._init_items(self.default) self._init_items(*self.args) self.args = None diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 86412c2dbd..99755ae6c9 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -234,17 +234,37 @@ class Compiled(ClauseVisitor): object be dependent on the actual values of those bind parameters, even though it may reference those values as defaults.""" - def __init__(self, engine, statement, bindparams): + def __init__(self, engine, statement, parameters): + """constructs a new Compiled object. + + engine - SQLEngine to compile against + + statement - ClauseElement to be compiled + + parameters - optional dictionary indicating a set of bind parameters + specified with this Compiled object. These parameters are the "default" + values corresponding to the ClauseElement's BindParamClauses when the Compiled + is executed. In the case of an INSERT or UPDATE statement, these parameters + will also result in the creation of new BindParamClause objects for each key + and will also affect the generated column list in an INSERT statement and the SET + clauses of an UPDATE statement. The keys of the parameter dictionary can + either be the string names of columns or actual sqlalchemy.schema.Column objects.""" self.engine = engine - self.bindparams = bindparams + self.parameters = parameters self.statement = statement def __str__(self): """returns the string text of the generated SQL statement.""" raise NotImplementedError() def get_params(self, **params): - """returns the bind params for this compiled object, with values overridden by - those given in the **params dictionary""" + """returns the bind params for this compiled object. + + Will start with the default parameters specified when this Compiled object + was first constructed, and will override those values with those sent via + **params, which are key/value pairs. Each key should match one of the + BindParamClause objects compiled into this object; either the "key" or + "shortname" property of the BindParamClause. + """ raise NotImplementedError() def execute(self, *multiparams, **params): @@ -254,7 +274,7 @@ class Compiled(ClauseVisitor): else: params = self.get_params(**params) - return self.engine.execute(str(self), params, compiled=self, typemap=self.typemap) + return self.engine.execute_compiled(self, params) def scalar(self, *multiparams, **params): """executes this compiled object via the execute() method, then @@ -326,7 +346,7 @@ class ClauseElement(object): return [self] columns = property(lambda s: s._get_columns()) - def compile(self, engine = None, bindparams = None, typemap=None): + def compile(self, engine = None, parameters = None, typemap=None): """compiles this SQL expression using its underlying SQLEngine to produce a Compiled object. If no engine can be found, an ansisql engine is used. bindparams is a dictionary representing the default bind parameters to be used with @@ -337,7 +357,7 @@ class ClauseElement(object): if engine is None: raise "no SQLEngine could be located within this ClauseElement." - return engine.compile(self, bindparams = bindparams, typemap=typemap) + return engine.compile(self, parameters=parameters, typemap=typemap) def __str__(self): e = self.engine @@ -355,7 +375,7 @@ class ClauseElement(object): bindparams = multiparams[0] else: bindparams = params - c = self.compile(e, bindparams = bindparams) + c = self.compile(e, parameters=bindparams) return c.execute(*multiparams, **params) def scalar(self, *multiparams, **params): diff --git a/test/query.py b/test/query.py index 8fc9694f45..75088da57f 100644 --- a/test/query.py +++ b/test/query.py @@ -57,6 +57,23 @@ class QueryTest(PersistTest): print repr(users_with_date.select().execute().fetchall()) users_with_date.drop() + def testdefaults(self): + x = {'x':0} + def mydefault(): + x['x'] += 1 + return x['x'] + + t = Table('default_test1', db, + Column('col1', Integer, primary_key=True, default=mydefault), + Column('col2', String(20), default="imthedefault"), + Column('col3', String(20), default=func.count(1)), + ) + t.create() + t.insert().execute() + t.insert().execute() + t.insert().execute() + t.drop() + def testdelete(self): c = db.connection() diff --git a/test/select.py b/test/select.py index b144f8804a..8a0027beaa 100644 --- a/test/select.py +++ b/test/select.py @@ -357,7 +357,7 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable # check that the bind params sent along with a compile() call # get preserved when the params are retreived later s = select([table], table.c.id == bindparam('test')) - c = s.compile(bindparams = {'test' : 7}) + c = s.compile(parameters = {'test' : 7}) self.assert_(c.get_params() == {'test' : 7}) def testcorrelatedsubquery(self): @@ -425,7 +425,7 @@ class CRUDTest(SQLTest): self.runtest(update(table, table.c.id == 12, values = {table.c.id : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'}) s = table.update(table.c.id == 12, values = {table.c.name : 'lala'}) print str(s) - c = s.compile(bindparams = {'mytable_id':9,'name':'h0h0'}) + c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'}) print str(c) self.assert_(str(s) == str(c)) diff --git a/test/tables.py b/test/tables.py index d0d0692488..807ecf7648 100644 --- a/test/tables.py +++ b/test/tables.py @@ -12,7 +12,7 @@ db = testbase.db users = Table('users', db, - Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), + Column('user_id', Integer, Sequence('user_id_seq', optional=False), primary_key = True), Column('user_name', String(40)), ) diff --git a/test/testbase.py b/test/testbase.py index 435e26b762..df4c186c3c 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -81,8 +81,8 @@ class EngineAssert(object): """decorates a SQLEngine object to match the incoming queries against a set of assertions.""" def __init__(self, engine): self.engine = engine - self.realexec = engine.execute - engine.execute = self.execute + self.realexec = engine.execute_compiled + engine.execute_compiled = self.execute_compiled self.echo = engine.echo self.logger = engine.logger self.set_assert_list(None, None) @@ -93,9 +93,10 @@ class EngineAssert(object): self.assert_list = list if list is not None: self.assert_list.reverse() - def execute(self, statement, parameters, **kwargs): + def execute_compiled(self, compiled, parameters, **kwargs): self.engine.echo = self.echo self.engine.logger = self.logger + statement = str(compiled) if self.assert_list is not None: item = self.assert_list.pop() @@ -104,14 +105,7 @@ class EngineAssert(object): params = params() # deal with paramstyles of different engines - if isinstance(self.engine, sqlite.SQLiteSQLEngine): - paramstyle = 'named' - else: - db = self.engine.dbapi() - if db is not None: - paramstyle = db.paramstyle - else: - paramstyle = 'named' + paramstyle = self.engine.paramstyle if paramstyle == 'named': pass elif paramstyle =='pyformat': @@ -127,31 +121,10 @@ class EngineAssert(object): elif paramstyle=='numeric': repl = None counter = 0 - def append_arg(match): - names.append(match.group(1)) - if repl is None: - counter += 1 - return counter - else: - return repl - # substitute bind string in query, translate bind param - # dict to a list (or a list of dicts to a list of lists) - query = re.sub(r':([\w_]+)', append_arg, query) - if isinstance(params, list): - args = [] - for p in params: - l = [] - args.append(l) - for n in names: - l.append(p[n]) - else: - args = [] - for n in names: - args.append(params[n]) - params = args + query = re.sub(r':([\w_]+)', repl, query) self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) - return self.realexec(statement, parameters, **kwargs) + return self.realexec(compiled, parameters, **kwargs) class TTestSuite(unittest.TestSuite): -- 2.47.2