]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added "adapt_on_names" boolean flag to orm.aliased()
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Sep 2011 20:48:39 +0000 (16:48 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Sep 2011 20:48:39 +0000 (16:48 -0400)
    construct.  Allows an aliased() construct
    to link the ORM entity to a selectable that contains
    aggregates or other derived forms of a particular
    attribute, provided the name is the same as that
    of the entity mapped column.

CHANGES
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/util.py
test/orm/test_froms.py
test/orm/test_query.py

diff --git a/CHANGES b/CHANGES
index 642fc50c3c6bdad905b4e7ef0ef5f85804f02c50..7c6efecc402f6522bcf846c87ba340eaa4c7a495 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -14,6 +14,13 @@ CHANGES
      with the Session to proceed after a rollback
      when the Session.is_active is True.
      [ticket:2241]
+  - added "adapt_on_names" boolean flag to orm.aliased()
+    construct.  Allows an aliased() construct
+    to link the ORM entity to a selectable that contains
+    aggregates or other derived forms of a particular
+    attribute, provided the name is the same as that
+    of the entity mapped column.
 
   - Fixed a variety of synonym()-related regressions
     from 0.6:
index 4708852ea1fde3b84defafe5bb4db6e575e67ec8..f901d0a0be9f2e73f7a0be32da19604dfaa62cb9 100644 (file)
@@ -221,15 +221,61 @@ class AliasedClass(object):
         session.query(User, user_alias).\\
                         join((user_alias, User.id > user_alias.id)).\\
                         filter(User.name==user_alias.name)
-
+    
+    The resulting object is an instance of :class:`.AliasedClass`, however
+    it implements a ``__getattribute__()`` scheme which will proxy attribute
+    access to that of the ORM class being aliased.  All classmethods
+    on the mapped entity should also be available here, including 
+    hybrids created with the :ref:`hybrids_toplevel` extension,
+    which will receive the :class:`.AliasedClass` as the "class" argument
+    when classmethods are called.
+    
+    :param cls: ORM mapped entity which will be "wrapped" around an alias.
+    :param alias: a selectable, such as an :func:`.alias` or :func:`.select`
+     construct, which will be rendered in place of the mapped table of the
+     ORM entity.  If left as ``None``, an ordinary :class:`.Alias` of the 
+     ORM entity's mapped table will be generated.
+    :param name: A name which will be applied both to the :class:`.Alias`
+     if one is generated, as well as the name present in the "named tuple"
+     returned by the :class:`.Query` object when results are returned.
+    :param adapt_on_names: if True, more liberal "matching" will be used when
+     mapping the mapped columns of the ORM entity to those of the given selectable - 
+     a name-based match will be performed if the given selectable doesn't 
+     otherwise have a column that corresponds to one on the entity.  The 
+     use case for this is when associating an entity with some derived
+     selectable such as one that uses aggregate functions::
+     
+        class UnitPrice(Base):
+            __tablename__ = 'unit_price'
+            ...
+            unit_id = Column(Integer)
+            price = Column(Numeric)
+        
+        aggregated_unit_price = Session.query(
+                                    func.sum(UnitPrice.price).label('price')
+                                ).group_by(UnitPrice.unit_id).subquery()
+                                
+        aggregated_unit_price = aliased(UnitPrice, alias=aggregated_unit_price, adapt_on_names=True)
+    
+     Above, functions on ``aggregated_unit_price`` which
+     refer to ``.price`` will return the
+     ``fund.sum(UnitPrice.price).label('price')`` column,
+     as it is matched on the name "price".  Ordinarily, the "price" function wouldn't
+     have any "column correspondence" to the actual ``UnitPrice.price`` column
+     as it is not a proxy of the original.
+     
+     ``adapt_on_names`` is new in 0.7.3.
+        
     """
-    def __init__(self, cls, alias=None, name=None):
+    def __init__(self, cls, alias=None, name=None, adapt_on_names=False):
         self.__mapper = _class_to_mapper(cls)
         self.__target = self.__mapper.class_
+        self.__adapt_on_names = adapt_on_names
         if alias is None:
             alias = self.__mapper._with_polymorphic_selectable.alias(name=name)
         self.__adapter = sql_util.ClauseAdapter(alias,
-                                equivalents=self.__mapper._equivalent_columns)
+                                equivalents=self.__mapper._equivalent_columns,
+                                adapt_on_names=self.__adapt_on_names)
         self.__alias = alias
         # used to assign a name to the RowTuple object
         # returned by Query.
@@ -240,15 +286,18 @@ class AliasedClass(object):
         return {
             'mapper':self.__mapper, 
             'alias':self.__alias, 
-            'name':self._sa_label_name
+            'name':self._sa_label_name,
+            'adapt_on_names':self.__adapt_on_names,
         }
 
     def __setstate__(self, state):
         self.__mapper = state['mapper']
         self.__target = self.__mapper.class_
+        self.__adapt_on_names = state['adapt_on_names']
         alias = state['alias']
         self.__adapter = sql_util.ClauseAdapter(alias,
-                                equivalents=self.__mapper._equivalent_columns)
+                                equivalents=self.__mapper._equivalent_columns,
+                                adapt_on_names=self.__adapt_on_names)
         self.__alias = alias
         name = state['name']
         self._sa_label_name = name
@@ -300,11 +349,13 @@ class AliasedClass(object):
         return '<AliasedClass at 0x%x; %s>' % (
             id(self), self.__target.__name__)
 
-def aliased(element, alias=None, name=None):
+def aliased(element, alias=None, name=None, adapt_on_names=False):
     if isinstance(element, expression.FromClause):
+        if adapt_on_names:
+            raise sa_exc.ArgumentError("adapt_on_names only applies to ORM elements")
         return element.alias(name)
     else:
-        return AliasedClass(element, alias=alias, name=name)
+        return AliasedClass(element, alias=alias, name=name, adapt_on_names=adapt_on_names)
 
 def _orm_annotate(element, exclude=None):
     """Deep copy the given ClauseElement, annotating each element with the
index 61a95d764d7d19d9a6ac901108b0acbf867f945c..221a18e5195e1540b274ed397beff7e52299e1be 100644 (file)
@@ -675,18 +675,18 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
       s.c.col1 == table2.c.col1
 
     """
-    def __init__(self, selectable, equivalents=None, include=None, exclude=None):
+    def __init__(self, selectable, equivalents=None, include=None, exclude=None, adapt_on_names=False):
         self.__traverse_options__ = {'stop_on':[selectable]}
         self.selectable = selectable
         self.include = include
         self.exclude = exclude
         self.equivalents = util.column_dict(equivalents or {})
+        self.adapt_on_names = adapt_on_names
 
     def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET):
         newcol = self.selectable.corresponding_column(
                                     col, 
                                     require_embedded=require_embedded)
-
         if newcol is None and col in self.equivalents and col not in _seen:
             for equiv in self.equivalents[col]:
                 newcol = self._corresponding_column(equiv, 
@@ -694,6 +694,8 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
                                 _seen=_seen.union([col]))
                 if newcol is not None:
                     return newcol
+        if self.adapt_on_names and newcol is None:
+            newcol = self.selectable.c.get(col.name)
         return newcol
 
     def replace(self, col):
index b762721aff9e8323ac30dd6126cc11e1e1d40f43..e2fb55129d4f7dd87954a5d530ebcdfa62bc2382 100644 (file)
@@ -1576,6 +1576,37 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
             q = q.join(j)
             self.assert_compile(q, exp)
 
+    def test_aliased_adapt_on_names(self):
+        User, Address = self.classes.User, self.classes.Address
+
+        sess = Session()
+        agg_address = sess.query(Address.id, 
+                        func.sum(func.length(Address.email_address)).label('email_address')
+                        ).group_by(Address.user_id)
+        ag1 = aliased(Address, agg_address.subquery())
+        ag2 = aliased(Address, agg_address.subquery(), adapt_on_names=True)
+
+        # first, without adapt on names, 'email_address' isn't matched up - we get the raw "address"
+        # element in the SELECT
+        self.assert_compile(
+            sess.query(User, ag1.email_address).join(ag1, User.addresses).filter(ag1.email_address > 5),
+            "SELECT users.id AS users_id, users.name AS users_name, addresses.email_address "
+            "AS addresses_email_address FROM addresses, users JOIN "
+            "(SELECT addresses.id AS id, sum(length(addresses.email_address)) "
+            "AS email_address FROM addresses GROUP BY addresses.user_id) AS "
+            "anon_1 ON users.id = addresses.user_id WHERE addresses.email_address > :email_address_1"
+        )
+
+        # second, 'email_address' matches up to the aggreagte, and we get a smooth JOIN
+        # from users->subquery and that's it
+        self.assert_compile(
+            sess.query(User, ag2.email_address).join(ag2, User.addresses).filter(ag2.email_address > 5),
+            "SELECT users.id AS users_id, users.name AS users_name, "
+            "anon_1.email_address AS anon_1_email_address FROM users "
+            "JOIN (SELECT addresses.id AS id, sum(length(addresses.email_address)) "
+            "AS email_address FROM addresses GROUP BY addresses.user_id) AS "
+            "anon_1 ON users.id = addresses.user_id WHERE anon_1.email_address > :email_address_1",
+        )
 
 class SelectFromTest(QueryTest, AssertsCompiledSQL):
     run_setup_mappers = None
index 873be2b30ed139419c10fe78f9e9998db95574bb..4ce8104e454291c5aa6a77ef5ac101289a808830 100644 (file)
@@ -770,6 +770,16 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL):
             "ON users.id = addresses.user_id) AS anon_1"
         )
 
+    def test_aliased_sql_construct_raises_adapt_on_names(self):
+        User, Address = self.classes.User, self.classes.Address
+
+        j = join(User, Address)
+        assert_raises_message(
+            sa_exc.ArgumentError,
+            "adapt_on_names only applies to ORM elements",
+            aliased, j, adapt_on_names=True
+        )
+
     def test_scalar_subquery_compile_whereclause(self):
         User = self.classes.User
         Address = self.classes.Address