]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed SQL compiler's awareness of top-level column labels as used
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Oct 2007 16:03:59 +0000 (16:03 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Oct 2007 16:03:59 +0000 (16:03 +0000)
  in result-set processing; nested selects which contain the same column
  names don't affect the result or conflict with result-column metadata.

- query.get() and related functions (like many-to-one lazyloading)
  use compile-time-aliased bind parameter names, to prevent
  name conflicts with bind parameters that already exist in the
  mapped selectable.

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/compiler.py
test/orm/mapper.py
test/orm/query.py
test/sql/query.py

diff --git a/CHANGES b/CHANGES
index b337d7a496f11825f9b454655bf8c95a243d3658..fe25f6662aa545481963d19193c777212a7ef675 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -30,6 +30,15 @@ CHANGES
   . FBDialect.table_names() doesn't bring system tables (ticket:796).
   . FB now reflects Column's nullable property correctly.
 
+- Fixed SQL compiler's awareness of top-level column labels as used
+  in result-set processing; nested selects which contain the same column 
+  names don't affect the result or conflict with result-column metadata.
+
+- query.get() and related functions (like many-to-one lazyloading)
+  use compile-time-aliased bind parameter names, to prevent
+  name conflicts with bind parameters that already exist in the 
+  mapped selectable.
+
 - Fixed three- and multi-level select and deferred inheritance loading
   (i.e. abc inheritance with no select_table), [ticket:795]
 
index 5e20cf6b641b2688adceec18ce158ce194124261..b68b4c8fe951c940d5fa0575decd466577a82b01 100644 (file)
@@ -464,9 +464,12 @@ class Mapper(object):
             self.__log("Identified primary key columns: " + str(primary_key))
         
             _get_clause = sql.and_()
+            _get_params = {}
             for primary_key in self.primary_key:
-                _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True))
-            self._get_clause = _get_clause
+                bind = sql.bindparam(None, type_=primary_key.type)
+                _get_params[primary_key] = bind
+                _get_clause.clauses.append(primary_key == bind)
+            self._get_clause = (_get_clause, _get_params)
 
     def _get_equivalent_columns(self):
         """Create a map of all *equivalent* columns, based on
index 0d04b768cb455379c67acf32524d782e19627a7c..f6268579f20955b2cd94eb607ae09e44d5c7e163 100644 (file)
@@ -708,16 +708,17 @@ class Query(object):
             ident = util.to_list(ident)
         params = {}
         
+        (_get_clause, _get_params) = self.select_mapper._get_clause
         for i, primary_key in enumerate(self.primary_key_columns):
             try:
-                params[primary_key._label] = ident[i]
+                params[_get_params[primary_key].key] = ident[i]
             except IndexError:
                 raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in self.primary_key_columns]))
         try:
             q = self
             if lockmode is not None:
                 q = q.with_lockmode(lockmode)
-            q = q.filter(self.select_mapper._get_clause)
+            q = q.filter(_get_clause)
             q = q.params(params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None))
             # call using all() to avoid LIMIT compilation complexity
             return q.all()[0]
index 5725b5f8eb4c44e906b9e39250fde43332016951..716a6dbba544383390affda96ffd7778ce73abf0 100644 (file)
@@ -200,11 +200,11 @@ class DeferredColumnLoader(LoaderStrategy):
                 raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
 
             if create_statement is None:
-                clause = localparent._get_clause
+                (clause, param_map) = localparent._get_clause
                 ident = instance._instance_key[1]
                 params = {}
                 for i, primary_key in enumerate(localparent.primary_key):
-                    params[primary_key._label] = ident[i]
+                    params[param_map[primary_key].key] = ident[i]
                 statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True)
             else:
                 statement, params = create_statement()
@@ -294,7 +294,7 @@ class LazyLoader(AbstractRelationLoader):
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
         #from sqlalchemy.orm import query
-        self.use_get = not self.uselist and self.mapper._get_clause.compare(self.lazywhere)
+        self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.lazywhere)
         if self.use_get:
             self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
 
index 05c7a5cf413f3ec689963f1bf00ebe9a0db5aacd..5572c2ed4eff29687aa3e3ddefc246a3779c36a9 100644 (file)
@@ -240,7 +240,7 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
     def visit_label(self, label):
         labelname = self._truncated_identifier("colident", label.name)
         
-        if self.stack and self.stack[-1].get('select'):
+        if len(self.stack) == 1 and self.stack[-1].get('select'):
             self.typemap.setdefault(labelname.lower(), label.obj.type)
             if isinstance(label.obj, sql._ColumnClause):
                 self.column_labels[label.obj._label] = labelname
@@ -258,7 +258,7 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         else:
             name = column.name
 
-        if self.stack and self.stack[-1].get('select'):
+        if len(self.stack) == 1 and self.stack[-1].get('select'):
             # if we are within a visit to a Select, set up the "typemap"
             # for this column which is used to translate result set values
             self.typemap.setdefault(name.lower(), column.type)
index d02feac47ec4ab2bf4488c6ceb57c940a0caa9b0..e723f968d10533132656c1705061a887bea29b89 100644 (file)
@@ -728,7 +728,7 @@ class DeferredTest(MapperSuperTest):
         orderby = str(orders.default_order_by()[0].compile(bind=testbase.db))
         self.assert_sql(testbase.db, go, [
             ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
-            ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
+            ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :param_1", {'param_1':3})
         ])
 
     def testunsaved(self):
@@ -791,7 +791,7 @@ class DeferredTest(MapperSuperTest):
         orderby = str(orders.default_order_by()[0].compile(testbase.db))
         self.assert_sql(testbase.db, go, [
             ("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}),
-            ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
+            ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :param_1", {'param_1':3})
         ])
         
         o2 = q.select()[2]
@@ -838,7 +838,7 @@ class DeferredTest(MapperSuperTest):
         orderby = str(orders.default_order_by()[0].compile(testbase.db))
         self.assert_sql(testbase.db, go, [
             ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
-            ("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
+            ("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :param_1", {'param_1':3})
         ])
         sess.clear()
         q3 = q2.options(undefer('user_id'))
index 49d3852e2f7797ea3ba27745df9c5fd73459e24f..547c7ce8842fe71b7bb9af5dcd5f0d1a0c89a594 100644 (file)
@@ -65,6 +65,21 @@ class GetTest(QueryTest):
         u2 = s.query(User).get(7)
         assert u is not u2
 
+    def test_unique_param_names(self):
+        class SomeUser(object):
+            pass
+        s = users.select(users.c.id!=12).alias('users')
+        m = mapper(SomeUser, s)
+        print s.primary_key
+        print m.primary_key
+        assert s.primary_key == m.primary_key
+        
+        row = s.select(use_labels=True).execute().fetchone()
+        print row[s.primary_key[0]]
+         
+        sess = create_session()
+        assert sess.query(SomeUser).get(7).name == 'jack'
+
     def test_load(self):
         s = create_session()
         
index eebfb7c0817b9ab816c513865ecbdb710276bc7e..77e1421a537b6e13faefb002af0098dc8228f141 100644 (file)
@@ -362,7 +362,18 @@ class QueryTest(PersistTest):
         except exceptions.InvalidRequestError, e:
             assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \
                    str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement."
-            
+    
+    def test_column_label_targeting(self):
+        users.insert().execute(user_id=7, user_name='ed')
+
+        for s in (
+            users.select().alias('foo'),
+            users.select().alias(users.name),
+        ):
+            row = s.select(use_labels=True).execute().fetchone()
+            assert row[s.c.user_id] == 7
+            assert row[s.c.user_name] == 'ed'
+        
     def test_keys(self):
         users.insert().execute(user_id=1, user_name='foo')
         r = users.select().execute().fetchone()