From 9b22fc0a9b5a6e97129096dd5ee8b3eb24895ac4 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 16 Oct 2007 16:03:59 +0000 Subject: [PATCH] - Fixed SQL compiler's awareness of top-level column labels as used in result-set processing; nested selects which contain the same column names don't affect the result or conflict with result-column metadata. - query.get() and related functions (like many-to-one lazyloading) use compile-time-aliased bind parameter names, to prevent name conflicts with bind parameters that already exist in the mapped selectable. --- CHANGES | 9 +++++++++ lib/sqlalchemy/orm/mapper.py | 7 +++++-- lib/sqlalchemy/orm/query.py | 5 +++-- lib/sqlalchemy/orm/strategies.py | 6 +++--- lib/sqlalchemy/sql/compiler.py | 4 ++-- test/orm/mapper.py | 6 +++--- test/orm/query.py | 15 +++++++++++++++ test/sql/query.py | 13 ++++++++++++- 8 files changed, 52 insertions(+), 13 deletions(-) diff --git a/CHANGES b/CHANGES index b337d7a496..fe25f6662a 100644 --- a/CHANGES +++ b/CHANGES @@ -30,6 +30,15 @@ CHANGES . FBDialect.table_names() doesn't bring system tables (ticket:796). . FB now reflects Column's nullable property correctly. +- Fixed SQL compiler's awareness of top-level column labels as used + in result-set processing; nested selects which contain the same column + names don't affect the result or conflict with result-column metadata. + +- query.get() and related functions (like many-to-one lazyloading) + use compile-time-aliased bind parameter names, to prevent + name conflicts with bind parameters that already exist in the + mapped selectable. + - Fixed three- and multi-level select and deferred inheritance loading (i.e. abc inheritance with no select_table), [ticket:795] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5e20cf6b64..b68b4c8fe9 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -464,9 +464,12 @@ class Mapper(object): self.__log("Identified primary key columns: " + str(primary_key)) _get_clause = sql.and_() + _get_params = {} for primary_key in self.primary_key: - _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True)) - self._get_clause = _get_clause + bind = sql.bindparam(None, type_=primary_key.type) + _get_params[primary_key] = bind + _get_clause.clauses.append(primary_key == bind) + self._get_clause = (_get_clause, _get_params) def _get_equivalent_columns(self): """Create a map of all *equivalent* columns, based on diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 0d04b768cb..f6268579f2 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -708,16 +708,17 @@ class Query(object): ident = util.to_list(ident) params = {} + (_get_clause, _get_params) = self.select_mapper._get_clause for i, primary_key in enumerate(self.primary_key_columns): try: - params[primary_key._label] = ident[i] + params[_get_params[primary_key].key] = ident[i] except IndexError: raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns])) try: q = self if lockmode is not None: q = q.with_lockmode(lockmode) - q = q.filter(self.select_mapper._get_clause) + q = q.filter(_get_clause) q = q.params(params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None)) # call using all() to avoid LIMIT compilation complexity return q.all()[0] diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 5725b5f8eb..716a6dbba5 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -200,11 +200,11 @@ class DeferredColumnLoader(LoaderStrategy): raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key)) if create_statement is None: - clause = localparent._get_clause + (clause, param_map) = localparent._get_clause ident = instance._instance_key[1] params = {} for i, primary_key in enumerate(localparent.primary_key): - params[primary_key._label] = ident[i] + params[param_map[primary_key].key] = ident[i] statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True) else: statement, params = create_statement() @@ -294,7 +294,7 @@ class LazyLoader(AbstractRelationLoader): # determine if our "lazywhere" clause is the same as the mapper's # get() clause. then we can just use mapper.get() #from sqlalchemy.orm import query - self.use_get = not self.uselist and self.mapper._get_clause.compare(self.lazywhere) + self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.lazywhere) if self.use_get: self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads") diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 05c7a5cf41..5572c2ed4e 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -240,7 +240,7 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): def visit_label(self, label): labelname = self._truncated_identifier("colident", label.name) - if self.stack and self.stack[-1].get('select'): + if len(self.stack) == 1 and self.stack[-1].get('select'): self.typemap.setdefault(labelname.lower(), label.obj.type) if isinstance(label.obj, sql._ColumnClause): self.column_labels[label.obj._label] = labelname @@ -258,7 +258,7 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): else: name = column.name - if self.stack and self.stack[-1].get('select'): + if len(self.stack) == 1 and self.stack[-1].get('select'): # if we are within a visit to a Select, set up the "typemap" # for this column which is used to translate result set values self.typemap.setdefault(name.lower(), column.type) diff --git a/test/orm/mapper.py b/test/orm/mapper.py index d02feac47e..e723f968d1 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -728,7 +728,7 @@ class DeferredTest(MapperSuperTest): orderby = str(orders.default_order_by()[0].compile(bind=testbase.db)) self.assert_sql(testbase.db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), - ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) + ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :param_1", {'param_1':3}) ]) def testunsaved(self): @@ -791,7 +791,7 @@ class DeferredTest(MapperSuperTest): orderby = str(orders.default_order_by()[0].compile(testbase.db)) self.assert_sql(testbase.db, go, [ ("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}), - ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) + ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :param_1", {'param_1':3}) ]) o2 = q.select()[2] @@ -838,7 +838,7 @@ class DeferredTest(MapperSuperTest): orderby = str(orders.default_order_by()[0].compile(testbase.db)) self.assert_sql(testbase.db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), - ("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) + ("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :param_1", {'param_1':3}) ]) sess.clear() q3 = q2.options(undefer('user_id')) diff --git a/test/orm/query.py b/test/orm/query.py index 49d3852e2f..547c7ce884 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -65,6 +65,21 @@ class GetTest(QueryTest): u2 = s.query(User).get(7) assert u is not u2 + def test_unique_param_names(self): + class SomeUser(object): + pass + s = users.select(users.c.id!=12).alias('users') + m = mapper(SomeUser, s) + print s.primary_key + print m.primary_key + assert s.primary_key == m.primary_key + + row = s.select(use_labels=True).execute().fetchone() + print row[s.primary_key[0]] + + sess = create_session() + assert sess.query(SomeUser).get(7).name == 'jack' + def test_load(self): s = create_session() diff --git a/test/sql/query.py b/test/sql/query.py index eebfb7c081..77e1421a53 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -362,7 +362,18 @@ class QueryTest(PersistTest): except exceptions.InvalidRequestError, e: assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \ str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement." - + + def test_column_label_targeting(self): + users.insert().execute(user_id=7, user_name='ed') + + for s in ( + users.select().alias('foo'), + users.select().alias(users.name), + ): + row = s.select(use_labels=True).execute().fetchone() + assert row[s.c.user_id] == 7 + assert row[s.c.user_name] == 'ed' + def test_keys(self): users.insert().execute(user_id=1, user_name='foo') r = users.select().execute().fetchone() -- 2.47.3