String, ForeignKey, Float, DateTime)
from sqlalchemy.orm import sessionmaker, mapper, relationship
from sqlalchemy.ext.horizontal_shard import ShardedSession
-from sqlalchemy.sql import operators
-from sqlalchemy import sql
+from sqlalchemy.sql import operators, visitors
+
import datetime
# step 2. databases
'South America':'south_america'
}
-# shard_chooser - looks at the given instance and returns a shard id
-# note that we need to define conditions for
-# the WeatherLocation class, as well as our secondary Report class which will
-# point back to its WeatherLocation via its 'location' attribute.
def shard_chooser(mapper, instance, clause=None):
+ """shard chooser.
+
+ looks at the given instance and returns a shard id
+ note that we need to define conditions for
+ the WeatherLocation class, as well as our secondary Report class which will
+ point back to its WeatherLocation via its 'location' attribute.
+
+ """
if isinstance(instance, WeatherLocation):
return shard_lookup[instance.continent]
else:
return shard_chooser(mapper, instance.location)
-# id_chooser. given a primary key, returns a list of shards
-# to search. here, we don't have any particular information from a
-# pk so we just return all shard ids. often, youd want to do some
-# kind of round-robin strategy here so that requests are evenly
-# distributed among DBs
def id_chooser(query, ident):
+ """id chooser.
+
+ given a primary key, returns a list of shards
+ to search. here, we don't have any particular information from a
+ pk so we just return all shard ids. often, youd want to do some
+ kind of round-robin strategy here so that requests are evenly
+ distributed among DBs.
+
+ """
return ['north_america', 'asia', 'europe', 'south_america']
-# query_chooser. this also returns a list of shard ids, which can
-# just be all of them. but here we'll search into the Query in order
-# to try to narrow down the list of shards to query.
def query_chooser(query):
+ """query chooser.
+
+ this also returns a list of shard ids, which can
+ just be all of them. but here we'll search into the Query in order
+ to try to narrow down the list of shards to query.
+
+ """
ids = []
- # here we will traverse through the query's criterion, searching
- # for SQL constructs. we'll grab continent names as we find them
+ # we'll grab continent names as we find them
# and convert to shard ids
- class FindContinent(sql.ClauseVisitor):
- def visit_binary(self, binary):
- # "shares_lineage()" returns True if both columns refer to the same
- # statement column, adjusting for any annotations present.
- # (an annotation is an internal clone of a Column object
- # and occur when using ORM-mapped attributes like
- # "WeatherLocation.continent"). A simpler comparison, though less accurate,
- # would be "binary.left.key == 'continent'".
- if binary.left.shares_lineage(weather_locations.c.continent):
- if binary.operator == operators.eq:
- ids.append(shard_lookup[binary.right.value])
- elif binary.operator == operators.in_op:
- for bind in binary.right.clauses:
- ids.append(shard_lookup[bind.value])
+ for column, operator, value in _get_query_comparisons(query):
+ # "shares_lineage()" returns True if both columns refer to the same
+ # statement column, adjusting for any annotations present.
+ # (an annotation is an internal clone of a Column object
+ # and occur when using ORM-mapped attributes like
+ # "WeatherLocation.continent"). A simpler comparison, though less accurate,
+ # would be "column.key == 'continent'".
+ if column.shares_lineage(weather_locations.c.continent):
+ if operator == operators.eq:
+ ids.append(shard_lookup[value])
+ elif operator == operators.in_op:
+ ids.extend(shard_lookup[v] for v in value)
- FindContinent().traverse(query._criterion)
if len(ids) == 0:
return ['north_america', 'asia', 'europe', 'south_america']
else:
return ids
+def _get_query_comparisons(query):
+ """Search an orm.Query object for binary expressions.
+
+ Returns expressions which match a Column against one or more
+ literal values as a list of tuples of the form
+ (column, operator, values). "values" is a single value
+ or tuple of values depending on the operator.
+
+ """
+ binds = {}
+ clauses = set()
+ comparisons = []
+
+ def visit_bindparam(bind):
+ # visit a bind parameter. Below we ensure
+ # that we get the value whether it was specified
+ # as part of query.params(), or is directly embedded
+ # in the bind's "value" attribute.
+ value = query._params.get(bind.key, bind.value)
+
+ # some ORM functions place the bind's value as a
+ # callable for deferred evaulation. Get that
+ # actual value here.
+ if callable(value):
+ value = value()
+
+ binds[bind] = value
+
+ def visit_column(column):
+ clauses.add(column)
+
+ def visit_binary(binary):
+ # special handling for "col IN (params)"
+ if binary.left in clauses and \
+ binary.operator == operators.in_op and \
+ hasattr(binary.right, 'clauses'):
+ comparisons.append(
+ (binary.left, binary.operator,
+ tuple(binds[bind] for bind in binary.right.clauses)
+ )
+ )
+ elif binary.left in clauses and binary.right in binds:
+ comparisons.append(
+ (binary.left, binary.operator,binds[binary.right])
+ )
+
+ elif binary.left in binds and binary.right in clauses:
+ comparisons.append(
+ (binary.right, binary.operator,binds[binary.left])
+ )
+
+ # here we will traverse through the query's criterion, searching
+ # for SQL constructs. We will place simple column comparisons
+ # into a list.
+ if query._criterion is not None:
+ visitors.traverse_depthfirst(query._criterion, {},
+ {'bindparam':visit_bindparam,
+ 'binary':visit_binary,
+ 'column':visit_column
+ }
+ )
+ return comparisons
+
# further configure create_session to use these functions
-create_session.configure(shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser)
+create_session.configure(
+ shard_chooser=shard_chooser,
+ id_chooser=id_chooser,
+ query_chooser=query_chooser
+ )
# step 6. mapped classes.
class WeatherLocation(object):