From 5a70cb7fa5345696731eac48958bee804f241df3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 5 Nov 2007 19:23:08 +0000 Subject: [PATCH] - adjustments to oracle ROWID logic...recent oid changes mean we have to use "rowid" against the select itself (i.e. its just...'rowid', no table name). seems to work OK but not sure if issues will arise - fixes to oracle bind param stuff to account for recent removal of ClauseParameters object. --- CHANGES | 4 ++++ lib/sqlalchemy/databases/oracle.py | 13 ++++++++----- lib/sqlalchemy/engine/default.py | 28 +++++++++++++++------------- lib/sqlalchemy/sql/compiler.py | 30 ++++++++++++++++-------------- test/dialect/oracle.py | 6 +++--- test/orm/eager_relations.py | 2 +- 6 files changed, 47 insertions(+), 36 deletions(-) diff --git a/CHANGES b/CHANGES index 7c0c77296e..c02f0ebb12 100644 --- a/CHANGES +++ b/CHANGES @@ -121,6 +121,10 @@ CHANGES - Added test coverage for unknown type reflection. Fixed sqlite/mysql handling of type reflection for unknown types. + - oracle uses plain "rowid" name when limiting against subqueries, + since this is the "rowid" of the enclosing query. if this raises + issues in the wild please file tickets ! + - misc - Removed unused util.hash(). diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 0197900a58..4def88afa2 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -204,8 +204,10 @@ class OracleExecutionContext(default.DefaultExecutionContext): if self.dialect.auto_setinputsizes: self.set_input_sizes() if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: - for key in self.compiled_parameters[0]: - (bindparam, name, value) = self.compiled_parameters[0].get_parameter(key) + for key in self.compiled.binds: + bindparam = self.compiled.binds[key] + name = self.compiled.bind_names[bindparam] + value = self.compiled_parameters[0][name] if bindparam.isoutparam: dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if not hasattr(self, 'out_parameters'): @@ -216,9 +218,10 @@ class OracleExecutionContext(default.DefaultExecutionContext): def get_result_proxy(self): if hasattr(self, 'out_parameters'): if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: - for k in self.out_parameters: - type = self.compiled_parameters[0].get_type(k) - self.out_parameters[k] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[k].getvalue()) + for bind, name in self.compiled.bind_names.iteritems(): + if name in self.out_parameters: + type = bind.type + self.out_parameters[name] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[name].getvalue()) else: for k in self.out_parameters: self.out_parameters[k] = self.out_parameters[k].getvalue() diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 771ca06f9a..198f6742bf 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -311,24 +311,26 @@ class DefaultExecutionContext(base.ExecutionContext): from the bind parameter's ``TypeEngine`` objects. """ - plist = self.compiled_parameters + types = dict([ + (self.compiled.bind_names[bindparam], bindparam.type) + for bindparam in self.compiled.bind_names + ]) + if self.dialect.positional: inputsizes = [] - for params in plist[0:1]: - for key in params.positional: - typeengine = params.get_type(key) - dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: - inputsizes.append(dbtype) + for key in self.compiled.positiontup: + typeengine = types[key] + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if dbtype is not None: + inputsizes.append(dbtype) self.cursor.setinputsizes(*inputsizes) else: inputsizes = {} - for params in plist[0:1]: - for key in params.keys(): - typeengine = params.get_type(key) - dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: - inputsizes[key.encode(self.dialect.encoding)] = dbtype + for key in self.compiled.bind_names.values(): + typeengine = types[key] + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if dbtype is not None: + inputsizes[key.encode(self.dialect.encoding)] = dbtype self.cursor.setinputsizes(**inputsizes) def __process_defaults(self): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ef66ffd5a6..9c8a6f56e3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -235,22 +235,24 @@ class DefaultCompiler(engine.Compiled): # for this column which is used to translate result set values self.typemap.setdefault(name.lower(), column.type) self.column_labels.setdefault(column._label, name.lower()) - - if column.table is None or not column.table.named_with_column(): - return self.preparer.format_column(column, name=name) - else: - if column.table.oid_column is column: - n = self.dialect.oid_column_name(column) - if n is not None: - return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)), n) - elif len(column.table.primary_key) != 0: - pk = list(column.table.primary_key)[0] - pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name)) - return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + + if column._is_oid: + n = self.dialect.oid_column_name(column) + if n is not None: + if column.table is None or not column.table.named_with_column(): + return self.preparer.format_column(column, name=n) else: - return None + return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)), n) + elif len(column.table.primary_key) != 0: + pk = list(column.table.primary_key)[0] + pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name)) + return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) else: - return self.preparer.format_column_with_table(column, column_name=name, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + return None + elif column.table is None or not column.table.named_with_column(): + return self.preparer.format_column(column, name=name) + else: + return self.preparer.format_column_with_table(column, column_name=name, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) def visit_fromclause(self, fromclause, **kwargs): diff --git a/test/dialect/oracle.py b/test/dialect/oracle.py index 03b68d0632..7d544fca54 100644 --- a/test/dialect/oracle.py +++ b/test/dialect/oracle.py @@ -46,17 +46,17 @@ class CompileTest(SQLCompileTest): s = select([t]).limit(10).offset(20) self.assert_compile(s, "SELECT col1, col2 FROM (SELECT sometable.col1 AS col1, sometable.col2 AS col2, " - "ROW_NUMBER() OVER (ORDER BY sometable.rowid) AS ora_rn FROM sometable) WHERE ora_rn>20 AND ora_rn<=30" + "ROW_NUMBER() OVER (ORDER BY rowid) AS ora_rn FROM sometable) WHERE ora_rn>20 AND ora_rn<=30" ) s = select([s.c.col1, s.c.col2]) self.assert_compile(s, "SELECT col1, col2 FROM (SELECT col1, col2 FROM (SELECT sometable.col1 AS col1, " - "sometable.col2 AS col2, ROW_NUMBER() OVER (ORDER BY sometable.rowid) AS ora_rn FROM sometable) WHERE ora_rn>20 AND ora_rn<=30)") + "sometable.col2 AS col2, ROW_NUMBER() OVER (ORDER BY rowid) AS ora_rn FROM sometable) WHERE ora_rn>20 AND ora_rn<=30)") # testing this twice to ensure oracle doesn't modify the original statement self.assert_compile(s, "SELECT col1, col2 FROM (SELECT col1, col2 FROM (SELECT sometable.col1 AS col1, " - "sometable.col2 AS col2, ROW_NUMBER() OVER (ORDER BY sometable.rowid) AS ora_rn FROM sometable) WHERE ora_rn>20 AND ora_rn<=30)") + "sometable.col2 AS col2, ROW_NUMBER() OVER (ORDER BY rowid) AS ora_rn FROM sometable) WHERE ora_rn>20 AND ora_rn<=30)") def test_outer_join(self): table1 = table('mytable', diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index 8f42b5128d..61c41dd7e9 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -463,7 +463,7 @@ class SelfReferentialEagerTest(ORMTest): def define_tables(self, metadata): global nodes nodes = Table('nodes', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('node_id_seq', optional=True), primary_key=True), Column('parent_id', Integer, ForeignKey('nodes.id')), Column('data', String(30))) -- 2.47.2