From: Mike Bayer Date: Mon, 13 Mar 2006 02:00:21 +0000 (+0000) Subject: oracle is requiring dictionary params to be in a clean dict, added conversion X-Git-Tag: rel_0_1_4~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1fc3d226558837b69ea8234f405b8dd6c8710b0a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git oracle is requiring dictionary params to be in a clean dict, added conversion some fixes to unit tests --- diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index e44e0a9509..8a29b847fc 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -586,14 +586,7 @@ class SQLEngine(schema.SchemaEngine): if statement is None: return cursor - executemany = parameters is not None and isinstance(parameters, list) - - if self.positional: - if executemany: - parameters = [p.values() for p in parameters] - else: - parameters = parameters.values() - + parameters = self._convert_compiled_params(parameters) self.execute(statement, parameters, connection=connection, cursor=cursor, return_raw=True) return cursor @@ -657,12 +650,12 @@ class SQLEngine(schema.SchemaEngine): return ResultProxy(cursor, self, typemap=typemap) def _execute(self, c, statement, parameters): + if parameters is None: + if self.positional: + parameters = () + else: + parameters = {} try: - if parameters is None: - if self.positional: - parameters = () - else: - parameters = {} c.execute(statement, parameters) except Exception, e: raise exceptions.SQLError(statement, parameters, e) @@ -671,15 +664,26 @@ class SQLEngine(schema.SchemaEngine): c.executemany(statement, parameters) self.context.rowcount = c.rowcount - def proxy(self, statement=None, parameters=None): + def _convert_compiled_params(self, parameters): executemany = parameters is not None and isinstance(parameters, list) + # the bind params are a CompiledParams object. but all the DBAPI's hate + # that object (or similar). so convert it to a clean + # dictionary/list/tuple of dictionary/tuple of list + if parameters is not None: + if self.positional: + if executemany: + parameters = [p.values() for p in parameters] + else: + parameters = parameters.values() + else: + if executemany: + parameters = [p.get_raw_dict() for p in parameters] + else: + parameters = parameters.get_raw_dict() + return parameters - if self.positional: - if executemany: - parameters = [p.values() for p in parameters] - else: - parameters = parameters.values() - + def proxy(self, statement=None, parameters=None): + parameters = self._convert_compiled_params(parameters) return self.execute(statement, parameters) def log(self, msg): diff --git a/lib/sqlalchemy/exceptions.py b/lib/sqlalchemy/exceptions.py index 6883293ec5..e270225d8f 100644 --- a/lib/sqlalchemy/exceptions.py +++ b/lib/sqlalchemy/exceptions.py @@ -17,6 +17,8 @@ class SQLError(SQLAlchemyError): self.statement = statement self.params = params self.orig = orig + def __str__(self): + return SQLAlchemyError.__str__(self) + " " + repr(self.statement) + " " + repr(self.params) class ArgumentError(SQLAlchemyError): """raised for all those conditions where invalid arguments are sent to constructed @@ -38,4 +40,4 @@ class AssertionError(SQLAlchemyError): class DBAPIError(SQLAlchemyError): """something weird happened with a particular DBAPI version""" - pass \ No newline at end of file + pass diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index a6fe3e8800..7ed013dfcf 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -252,6 +252,11 @@ class ClauseParameters(util.OrderedDict): return [self[key] for key in self] def get_original_dict(self): return self.copy() + def get_raw_dict(self): + d = {} + for k in self: + d[k] = self[k] + return d class ClauseVisitor(object): """Defines the visiting of ClauseElements.""" diff --git a/test/objectstore.py b/test/objectstore.py index 0a02d6fb11..312e1b53f1 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -144,7 +144,7 @@ class SessionTest(AssertMixin): class UnicodeTest(AssertMixin): def setUpAll(self): global uni_table - uni_table = Table('test', db, + uni_table = Table('uni_test', db, Column('id', Integer, primary_key=True), Column('txt', Unicode(50))).create() diff --git a/test/testtypes.py b/test/testtypes.py index d44c439bdb..b874fe0afe 100644 --- a/test/testtypes.py +++ b/test/testtypes.py @@ -96,7 +96,7 @@ class UnicodeTest(AssertMixin): def setUpAll(self): global unicode_table unicode_table = Table('unicode_table', db, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True), Column('unicode_data', Unicode(250)), Column('plain_data', String(250)) )