]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Updated attribute_shard.py example to use a more robust
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Apr 2010 17:10:55 +0000 (13:10 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Apr 2010 17:10:55 +0000 (13:10 -0400)
method of searching a Query for binary expressions which
compare columns against literal values.

CHANGES
examples/sharding/attribute_shard.py

diff --git a/CHANGES b/CHANGES
index fc217179c6eb07cbc36171a9ab2b48a1fb795053..4831295913b5017faced17731071a9508e210e0e 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -44,15 +44,20 @@ CHANGES
     has been added.  [ticket:1755]
     
 - ext
-   - the compiler extension now allows @compiles decorators
-     on base classes that extend to child classes, @compiles
-     decorators on child classes that aren't broken by a 
-     @compiles decorator on the base class.
+  - the compiler extension now allows @compiles decorators
+    on base classes that extend to child classes, @compiles
+    decorators on child classes that aren't broken by a 
+    @compiles decorator on the base class.
    
-   - Declarative will raise an informative error message
-     if a non-mapped class attribute is referenced in the
-     string-based relationship() arguments.
-     
+  - Declarative will raise an informative error message
+    if a non-mapped class attribute is referenced in the
+    string-based relationship() arguments.
+
+- examples
+  - Updated attribute_shard.py example to use a more robust
+    method of searching a Query for binary expressions which
+    compare columns against literal values.
+    
 0.6beta3
 ========
 
index 07f76c30919cd783ef7d58fe1f98953420297679..1a39f5de32c488c1917936547a0035ebafb93869 100644 (file)
@@ -4,8 +4,8 @@ from sqlalchemy import (create_engine, MetaData, Table, Column, Integer,
     String, ForeignKey, Float, DateTime)
 from sqlalchemy.orm import sessionmaker, mapper, relationship
 from sqlalchemy.ext.horizontal_shard import ShardedSession
-from sqlalchemy.sql import operators
-from sqlalchemy import sql
+from sqlalchemy.sql import operators, visitors
+
 import datetime
 
 # step 2. databases
@@ -87,56 +87,131 @@ shard_lookup = {
     'South America':'south_america'
 }
 
-# shard_chooser - looks at the given instance and returns a shard id
-# note that we need to define conditions for 
-# the WeatherLocation class, as well as our secondary Report class which will
-# point back to its WeatherLocation via its 'location' attribute.
 def shard_chooser(mapper, instance, clause=None):
+    """shard chooser.
+    
+    looks at the given instance and returns a shard id
+    note that we need to define conditions for 
+    the WeatherLocation class, as well as our secondary Report class which will
+    point back to its WeatherLocation via its 'location' attribute.
+    
+    """
     if isinstance(instance, WeatherLocation):
         return shard_lookup[instance.continent]
     else:
         return shard_chooser(mapper, instance.location)
 
-# id_chooser.  given a primary key, returns a list of shards
-# to search.  here, we don't have any particular information from a
-# pk so we just return all shard ids. often, youd want to do some 
-# kind of round-robin strategy here so that requests are evenly 
-# distributed among DBs
 def id_chooser(query, ident):
+    """id chooser.  
+    
+    given a primary key, returns a list of shards
+    to search.  here, we don't have any particular information from a
+    pk so we just return all shard ids. often, youd want to do some 
+    kind of round-robin strategy here so that requests are evenly 
+    distributed among DBs.
+    
+    """
     return ['north_america', 'asia', 'europe', 'south_america']
 
-# query_chooser.  this also returns a list of shard ids, which can
-# just be all of them.  but here we'll search into the Query in order
-# to try to narrow down the list of shards to query.
 def query_chooser(query):
+    """query chooser.
+    
+    this also returns a list of shard ids, which can
+    just be all of them.  but here we'll search into the Query in order
+    to try to narrow down the list of shards to query.
+    
+    """
     ids = []
 
-    # here we will traverse through the query's criterion, searching
-    # for SQL constructs.  we'll grab continent names as we find them
+    # we'll grab continent names as we find them
     # and convert to shard ids
-    class FindContinent(sql.ClauseVisitor):
-        def visit_binary(self, binary):
-            # "shares_lineage()" returns True if both columns refer to the same
-            # statement column, adjusting for any annotations present.
-            # (an annotation is an internal clone of a Column object
-            # and occur when using ORM-mapped attributes like 
-            # "WeatherLocation.continent"). A simpler comparison, though less accurate, 
-            # would be "binary.left.key == 'continent'".
-            if binary.left.shares_lineage(weather_locations.c.continent):
-                if binary.operator == operators.eq:
-                    ids.append(shard_lookup[binary.right.value])
-                elif binary.operator == operators.in_op:
-                    for bind in binary.right.clauses:
-                        ids.append(shard_lookup[bind.value])
+    for column, operator, value in _get_query_comparisons(query):
+        # "shares_lineage()" returns True if both columns refer to the same
+        # statement column, adjusting for any annotations present.
+        # (an annotation is an internal clone of a Column object
+        # and occur when using ORM-mapped attributes like 
+        # "WeatherLocation.continent"). A simpler comparison, though less accurate, 
+        # would be "column.key == 'continent'".
+        if column.shares_lineage(weather_locations.c.continent):
+            if operator == operators.eq:
+                ids.append(shard_lookup[value])
+            elif operator == operators.in_op:
+                ids.extend(shard_lookup[v] for v in value)
                     
-    FindContinent().traverse(query._criterion)
     if len(ids) == 0:
         return ['north_america', 'asia', 'europe', 'south_america']
     else:
         return ids
 
+def _get_query_comparisons(query):
+    """Search an orm.Query object for binary expressions.
+    
+    Returns expressions which match a Column against one or more
+    literal values as a list of tuples of the form 
+    (column, operator, values).   "values" is a single value
+    or tuple of values depending on the operator.
+    
+    """
+    binds = {}
+    clauses = set()
+    comparisons = []
+
+    def visit_bindparam(bind):
+        # visit a bind parameter.   Below we ensure
+        # that we get the value whether it was specified
+        # as part of query.params(), or is directly embedded
+        # in the bind's "value" attribute.
+        value = query._params.get(bind.key, bind.value)
+
+        # some ORM functions place the bind's value as a 
+        # callable for deferred evaulation.   Get that
+        # actual value here.
+        if callable(value):
+            value = value()
+
+        binds[bind] = value
+
+    def visit_column(column):
+        clauses.add(column)
+
+    def visit_binary(binary):
+        # special handling for "col IN (params)"
+        if binary.left in clauses and \
+                binary.operator == operators.in_op and \
+                hasattr(binary.right, 'clauses'):
+            comparisons.append(
+                (binary.left, binary.operator, 
+                    tuple(binds[bind] for bind in binary.right.clauses)
+                )
+            )
+        elif binary.left in clauses and binary.right in binds:
+            comparisons.append(
+                (binary.left, binary.operator,binds[binary.right])
+            )
+
+        elif binary.left in binds and binary.right in clauses:
+            comparisons.append(
+                (binary.right, binary.operator,binds[binary.left])
+            )
+
+    # here we will traverse through the query's criterion, searching
+    # for SQL constructs.  We will place simple column comparisons
+    # into a list.
+    if query._criterion is not None:
+        visitors.traverse_depthfirst(query._criterion, {},
+                    {'bindparam':visit_bindparam,
+                        'binary':visit_binary,
+                        'column':visit_column
+                    }
+        )
+    return comparisons
+
 # further configure create_session to use these functions
-create_session.configure(shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser)
+create_session.configure(
+                    shard_chooser=shard_chooser, 
+                    id_chooser=id_chooser, 
+                    query_chooser=query_chooser
+                    )
 
 # step 6.  mapped classes.    
 class WeatherLocation(object):