]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- a rudimental sharding (horizontal scaling) system is introduced. This system
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Jul 2007 04:05:55 +0000 (04:05 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Jul 2007 04:05:55 +0000 (04:05 +0000)
uses a modified Session which can distribute read and write operations among
multiple databases, based on user-defined functions defining the
"sharding strategy".  Instances and their dependents can be distributed
and queried among multiple databases based on attribute values, round-robin
approaches or any other user-defined system. [ticket:618]

CHANGES
examples/sharding/attribute_shard.py [new file with mode: 0644]
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/shard.py [new file with mode: 0644]
lib/sqlalchemy/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index 092ee176ec3651a1fc4cde81725cff4dc5087558..668c94626c15aec4a2a19284d9f17e92944eeb60 100644 (file)
--- a/CHANGES
+++ b/CHANGES
 
         - added eagerload_all(), allows eagerload_all('x.y.z') to specify eager
           loading of all properties in the given path
+    
+    - a rudimental sharding (horizontal scaling) system is introduced.  This system
+      uses a modified Session which can distribute read and write operations among
+      multiple databases, based on user-defined functions defining the 
+      "sharding strategy".  Instances and their dependents can be distributed 
+      and queried among multiple databases based on attribute values, round-robin
+      approaches or any other user-defined system. [ticket:618]
           
     - Eager loading has been enhanced to allow even more joins in more places.
       It now functions at any arbitrary depth along self-referential 
@@ -76,7 +83,6 @@
       themselves are generated at compile time using a simple counting
       scheme now and are a lot easier on the eyes, as well as of course
       completely deterministic. [ticket:659]
-
       
     - added composite column properties. This allows you to create a 
       type which is represented by more than one column, when using the 
diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py
new file mode 100644 (file)
index 0000000..1f80c54
--- /dev/null
@@ -0,0 +1,194 @@
+"""a basic example of using the SQLAlchemy Sharding API.
+Sharding refers to horizontally scaling data across multiple
+databases.
+
+In this example, four sqlite databases will store information about
+weather data on a database-per-continent basis.
+
+To set up a sharding system, you need:
+    1. multiple databases, each assined a 'shard id'
+    2. a function which can return a single shard id, given an instance
+    to be saved; this is called "shard_chooser"
+    3. a function which can return a list of shard ids which apply to a particular
+    instance identifier; this is called "id_chooser".  If it returns all shard ids,
+    all shards will be searched.
+    4. a function which can return a list of shard ids to try, given a particular 
+    Query ("query_chooser").  If it returns all shard ids, all shards will be 
+    queried and the results joined together.
+"""
+
+# step 1. imports
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from sqlalchemy.orm.shard import ShardedSession
+from sqlalchemy.sql import ColumnOperators
+import datetime, operator
+
+# step 2. databases
+echo = True
+db1 = create_engine('sqlite:///shard1.db', echo=echo)
+db2 = create_engine('sqlite:///shard2.db', echo=echo)
+db3 = create_engine('sqlite:///shard3.db', echo=echo)
+db4 = create_engine('sqlite:///shard4.db', echo=echo)
+
+
+# step 3. create session function.  this binds the shard ids
+# to databases within a ShardedSession and returns it.
+def create_session():
+    s = ShardedSession(shard_chooser, id_chooser, query_chooser)
+    s.bind_shard('north_america', db1)
+    s.bind_shard('asia', db2)
+    s.bind_shard('europe', db3)
+    s.bind_shard('south_america', db4)
+    return s
+
+# step 4.  table setup.
+meta = MetaData()
+
+# we need a way to create identifiers which are unique across all
+# databases.  one easy way would be to just use a composite primary key, where one
+# value is the shard id.  but here, we'll show something more "generic", an 
+# id generation function.  we'll use a simplistic "id table" stored in database
+# #1.  Any other method will do just as well; UUID, hilo, application-specific, etc.
+
+ids = Table('ids', meta,
+    Column('nextid', Integer, nullable=False))
+
+def id_generator(ctx):
+    # in reality, might want to use a separate transaction for this.
+    c = db1.connect()
+    nextid = c.execute(ids.select(for_update=True)).scalar()
+    c.execute(ids.update(values={ids.c.nextid : ids.c.nextid + 1}))
+    return nextid
+
+# table setup.  we'll store a lead table of continents/cities,
+# and a secondary table storing locations.
+# a particular row will be placed in the database whose shard id corresponds to the
+# 'continent'.  in this setup, secondary rows in 'weather_reports' will 
+# be placed in the same DB as that of the parent, but this can be changed
+# if you're willing to write more complex sharding functions.
+
+weather_locations = Table("weather_locations", meta,
+        Column('id', Integer, primary_key=True, default=id_generator),
+        Column('continent', String(30), nullable=False),
+        Column('city', String(50), nullable=False)
+    )
+
+weather_reports = Table("weather_reports", meta,
+    Column('id', Integer, primary_key=True),
+    Column('location_id', Integer, ForeignKey('weather_locations.id')),
+    Column('temperature', Numeric),
+    Column('report_time', DateTime, default=datetime.datetime.now),
+)
+
+# create tables
+for db in (db1, db2, db3, db4):
+    meta.drop_all(db)
+    meta.create_all(db)
+    
+# establish initial "id" in db1
+db1.execute(ids.insert(), nextid=1)
+
+
+# step 5. define sharding functions.  
+
+# we'll use a straight mapping of a particular set of "country" 
+# attributes to shard id.
+shard_lookup = {
+    'North America':'north_america',
+    'Asia':'asia',
+    'Europe':'europe',
+    'South America':'south_america'
+}
+
+# shard_chooser - looks at the given instance and returns a shard id
+# note that we need to define conditions for 
+# the WeatherLocation class, as well as our secondary Report class which will
+# point back to its WeatherLocation via its 'location' attribute.
+def shard_chooser(mapper, instance):
+    if isinstance(instance, WeatherLocation):
+        return shard_lookup[instance.continent]
+    else:
+        return shard_chooser(mapper, instance.location)
+
+# id_chooser.  given a primary key, returns a list of shards
+# to search.  here, we don't have any particular information from a
+# pk so we just return all shard ids. often, youd want to do some 
+# kind of round-robin strategy here so that requests are evenly 
+# distributed among DBs
+def id_chooser(ident):
+    return ['north_america', 'asia', 'europe', 'south_america']
+
+# query_chooser.  this also returns a list of shard ids, which can
+# just be all of them.  but here we'll search into the Query in order
+# to try to narrow down the list of shards to query.
+def query_chooser(query):
+    ids = []
+
+    # here we will traverse through the query's criterion, searching
+    # for SQL constructs.  we'll grab continent names as we find them
+    # and convert to shard ids
+    class FindContinent(sql.ClauseVisitor):
+        def visit_binary(self, binary):
+            if binary.left is weather_locations.c.continent:
+                if binary.operator == operator.eq:
+                    ids.append(shard_lookup[binary.right.value])
+                elif binary.operator == ColumnOperators.in_op:
+                    for bind in binary.right.clauses:
+                        ids.append(shard_lookup[bind.value])
+                    
+    FindContinent().traverse(query._criterion)
+    if len(ids) == 0:
+        return ['north_america', 'asia', 'europe', 'south_america']
+    else:
+        return ids
+
+# step 6.  mapped classes.    
+class WeatherLocation(object):
+    def __init__(self, continent, city):
+        self.continent = continent
+        self.city = city
+
+class Report(object):
+    def __init__(self, temperature):
+        self.temperature = temperature
+
+# step 7.  mappers
+mapper(WeatherLocation, weather_locations, properties={
+    'reports':relation(Report, backref='location')
+})
+
+mapper(Report, weather_reports)    
+
+
+# save and load objects!
+
+tokyo = WeatherLocation('Asia', 'Tokyo')
+newyork = WeatherLocation('North America', 'New York')
+toronto = WeatherLocation('North America', 'Toronto')
+london = WeatherLocation('Europe', 'London')
+dublin = WeatherLocation('Europe', 'Dublin')
+brasilia = WeatherLocation('South America', 'Brasila')
+quito = WeatherLocation('South America', 'Quito')
+
+tokyo.reports.append(Report(80.0))
+newyork.reports.append(Report(75))
+quito.reports.append(Report(85))
+
+sess = create_session()
+for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
+    sess.save(c)
+sess.flush()
+
+sess.clear()
+
+t = sess.query(WeatherLocation).get(tokyo.id)
+assert t.city == tokyo.city
+assert t.reports[0].temperature == 80.0
+
+north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America')
+assert [c.city for c in north_american_cities] == ['New York', 'Toronto']
+
+asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_('Europe', 'Asia'))
+assert set([c.city for c in asia_and_europe]) == set(['Tokyo', 'London', 'Dublin'])
+
index 62b4d7d8960e13625058c1647af09bd2289e1285..e035537351846dfa2a257b0ba32f467c70e277f1 100644 (file)
@@ -1095,10 +1095,15 @@ class Mapper(object):
                 self.save_obj([obj], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
             return
 
-        connection = uowtransaction.transaction.connection(self)
-
+        if 'connection_callable' in uowtransaction.mapper_flush_opts:
+            connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
+            tups = [(obj, connection_callable(self, obj)) for obj in objects]
+        else:
+            connection = uowtransaction.transaction.connection(self)
+            tups = [(obj, connection) for obj in objects]
+            
         if not postupdate:
-            for obj in objects:
+            for obj, connection in tups:
                 if not has_identity(obj):
                     for mapper in object_mapper(obj).iterate_to_root():
                         mapper.extension.before_insert(mapper, connection, obj)
@@ -1106,7 +1111,7 @@ class Mapper(object):
                     for mapper in object_mapper(obj).iterate_to_root():
                         mapper.extension.before_update(mapper, connection, obj)
 
-        for obj in objects:
+        for obj, connection in tups:
             # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
             # and another instance with the same identity key already exists as persistent.  convert to an
             # UPDATE if so.
@@ -1137,7 +1142,7 @@ class Mapper(object):
             insert = []
             update = []
 
-            for obj in objects:
+            for obj, connection in tups:
                 mapper = object_mapper(obj)
                 if table not in mapper.tables or not mapper._has_pks(table):
                     continue
@@ -1215,9 +1220,9 @@ class Mapper(object):
                     if hasdata:
                         # if none of the attributes changed, dont even
                         # add the row to be updated.
-                        update.append((obj, params, mapper))
+                        update.append((obj, params, mapper, connection))
                 else:
-                    insert.append((obj, params, mapper))
+                    insert.append((obj, params, mapper, connection))
 
             if len(update):
                 mapper = table_to_mapper[table]
@@ -1237,11 +1242,11 @@ class Mapper(object):
                     return 0
                 update.sort(comparator)
                 for rec in update:
-                    (obj, params, mapper) = rec
+                    (obj, params, mapper, connection) = rec
                     c = connection.execute(statement, params)
                     mapper._postfetch(connection, table, obj, c, c.last_updated_params())
 
-                    updated_objects.add(obj)
+                    updated_objects.add((obj, connection))
                     rows += c.rowcount
 
                 if c.supports_sane_rowcount() and rows != len(update):
@@ -1253,7 +1258,7 @@ class Mapper(object):
                     return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order)
                 insert.sort(comparator)
                 for rec in insert:
-                    (obj, params, mapper) = rec
+                    (obj, params, mapper, connection) = rec
                     c = connection.execute(statement, params)
                     primary_key = c.last_inserted_ids()
                     if primary_key is not None:
@@ -1275,12 +1280,12 @@ class Mapper(object):
                             mapper._synchronizer.execute(obj, obj)
                     sync(mapper)
 
-                    inserted_objects.add(obj)
+                    inserted_objects.add((obj, connection))
         if not postupdate:
-            for obj in inserted_objects:
+            for obj, connection in inserted_objects:
                 for mapper in object_mapper(obj).iterate_to_root():
                     mapper.extension.after_insert(mapper, connection, obj)
-            for obj in updated_objects:
+            for obj, connection in updated_objects:
                 for mapper in object_mapper(obj).iterate_to_root():
                     mapper.extension.after_update(mapper, connection, obj)
 
@@ -1320,9 +1325,14 @@ class Mapper(object):
         if self.__should_log_debug:
             self.__log_debug("delete_obj() start")
 
-        connection = uowtransaction.transaction.connection(self)
+        if 'connection_callable' in uowtransaction.mapper_flush_opts:
+            connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
+            tups = [(obj, connection_callable(self, obj)) for obj in objects]
+        else:
+            connection = uowtransaction.transaction.connection(self)
+            tups = [(obj, connection) for obj in objects]
 
-        for obj in objects:
+        for (obj, connection) in tups:
             for mapper in object_mapper(obj).iterate_to_root():
                 mapper.extension.before_delete(mapper, connection, obj)
         
@@ -1333,8 +1343,8 @@ class Mapper(object):
                 table_to_mapper.setdefault(t, mapper)
 
         for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=True):
-            delete = []
-            for obj in objects:
+            delete = {}
+            for (obj, connection) in tups:
                 mapper = object_mapper(obj)
                 if table not in mapper.tables or not mapper._has_pks(table):
                     continue
@@ -1343,13 +1353,13 @@ class Mapper(object):
                 if not hasattr(obj, '_instance_key'):
                     continue
                 else:
-                    delete.append(params)
+                    delete.setdefault(connection, []).append(params)
                 for col in mapper.pks_by_table[table]:
                     params[col.key] = mapper.get_attr_by_column(obj, col)
                 if mapper.version_id_col is not None:
                     params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col)
-                deleted_objects.add(obj)
-            if len(delete):
+                deleted_objects.add((obj, connection))
+            for connection, del_objects in delete.iteritems():
                 mapper = table_to_mapper[table]
                 def comparator(a, b):
                     for col in mapper.pks_by_table[table]:
@@ -1357,18 +1367,18 @@ class Mapper(object):
                         if x != 0:
                             return x
                     return 0
-                delete.sort(comparator)
+                del_objects.sort(comparator)
                 clause = sql.and_()
                 for col in mapper.pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True))
                 if mapper.version_id_col is not None:
                     clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True))
                 statement = table.delete(clause)
-                c = connection.execute(statement, delete)
-                if c.supports_sane_rowcount() and c.rowcount != len(delete):
+                c = connection.execute(statement, del_objects)
+                if c.supports_sane_rowcount() and c.rowcount != len(del_objects):
                     raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.rowcount, len(delete)))
 
-        for obj in deleted_objects:
+        for obj, connection in deleted_objects:
             for mapper in object_mapper(obj).iterate_to_root():
                 mapper.extension.after_delete(mapper, connection, obj)
 
index 86e6d522f3e43fbc4c709b1d547e2449109bc4ed..68a71e56552774ab9a8ddc31e8a0523ebebaf7d9 100644 (file)
@@ -81,7 +81,7 @@ class Query(object):
         key = self.mapper.identity_key_from_primary_key(ident)
         return self._get(key, ident, **kwargs)
 
-    def load(self, ident, **kwargs):
+    def load(self, ident, raiseerr=True, **kwargs):
         """Return an instance of the object based on the given
         identifier.
 
@@ -97,7 +97,7 @@ class Query(object):
             return ret
         key = self.mapper.identity_key_from_primary_key(ident)
         instance = self._get(key, ident, reload=True, **kwargs)
-        if instance is None:
+        if instance is None and raiseerr:
             raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
         return instance
         
@@ -608,13 +608,15 @@ class Query(object):
         statement.use_labels = True
         if self.session.autoflush:
             self.session.flush()
+        return self._execute_and_instances(statement)
+    
+    def _execute_and_instances(self, statement):
         result = self.session.execute(statement, params=self._params, mapper=self.mapper)
         try:
             return iter(self.instances(result))
         finally:
             result.close()
 
-
     def instances(self, cursor, *mappers_or_columns, **kwargs):
         """Return a list of mapped instances corresponding to the rows
         in a given *cursor* (i.e. ``ResultProxy``).
index d41eac2b581bfe908479615d27fd3d0cff735395..6b5c4a0725ee29d953a0213bd5b5656bda27fa85 100644 (file)
@@ -28,10 +28,10 @@ class SessionTransaction(object):
         self.autoflush = autoflush
         self.nested = nested
 
-    def connection(self, mapper_or_class, entity_name=None):
+    def connection(self, mapper_or_class, entity_name=None, **kwargs):
         if isinstance(mapper_or_class, type):
             mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
-        engine = self.session.get_bind(mapper_or_class)
+        engine = self.session.get_bind(mapper_or_class, **kwargs)
         return self.get_or_add(engine)
 
     def _begin(self, **kwargs):
@@ -137,7 +137,7 @@ class Session(object):
         self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
 
         self.bind = bind
-        self.binds = {}
+        self.__binds = {}
         self.echo_uow = echo_uow
         self.weak_identity_map = weak_identity_map
         self.transaction = None
@@ -145,6 +145,8 @@ class Session(object):
         self.autoflush = autoflush
         self.transactional = transactional or autoflush
         self.twophase = twophase
+        self._query_cls = query.Query
+        self._mapper_flush_opts = {}
         if self.transactional:
             self.begin()
         _sessions[self.hash_key] = self
@@ -185,7 +187,7 @@ class Session(object):
             self.transaction = self.transaction.commit()
         if self.transaction is None and self.transactional:
             self.begin()
-        
+    
     def connection(self, mapper=None, **kwargs):
         """Return a ``Connection`` corresponding to this session's
         transactional context, if any.
@@ -263,7 +265,7 @@ class Session(object):
         if isinstance(mapper, type):
             mapper = _class_mapper(mapper, entity_name=entity_name)
 
-        self.binds[mapper] = bind
+        self.__binds[mapper] = bind
 
     def bind_table(self, table, bind):
         """Bind the given `table` to the given ``Engine`` or ``Connection``.
@@ -272,7 +274,7 @@ class Session(object):
         given `bind`.
         """
 
-        self.binds[table] = bind
+        self.__binds[table] = bind
 
     def get_bind(self, mapper):
         """Return the ``Engine`` or ``Connection`` which is used to execute
@@ -306,10 +308,10 @@ class Session(object):
                 return self.bind
             else:
                 raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()")
-        elif self.binds.has_key(mapper):
-            return self.binds[mapper]
-        elif self.binds.has_key(mapper.mapped_table):
-            return self.binds[mapper.mapped_table]
+        elif self.__binds.has_key(mapper):
+            return self.__binds[mapper]
+        elif self.__binds.has_key(mapper.mapped_table):
+            return self.__binds[mapper.mapped_table]
         elif self.bind is not None:
             return self.bind
         else:
@@ -326,9 +328,9 @@ class Session(object):
         entity_name = kwargs.pop('entity_name', None)
         
         if isinstance(mapper_or_class, type):
-            q = query.Query(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
+            q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
         else:
-            q = query.Query(mapper_or_class, self, **kwargs)
+            q = self._query_cls(mapper_or_class, self, **kwargs)
             
         for ent in addtl_entities:
             q = q.add_entity(ent)
diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py
new file mode 100644 (file)
index 0000000..cc13f8c
--- /dev/null
@@ -0,0 +1,112 @@
+from sqlalchemy.orm.session import Session
+from sqlalchemy.orm import Query
+
+class ShardedSession(Session):
+    def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs):
+        """construct a ShardedSession.
+        
+            shard_chooser
+                a callable which, passed a Mapper and a mapped instance, returns a
+                shard ID.  this id may be based off of the attributes present within the
+                object, or on some round-robin scheme.  If the scheme is based on a
+                selection, it should set whatever state on the instance to mark it in
+                the future as participating in that shard.
+            
+            id_chooser
+                a callable, passed a tuple of identity values, which should return
+                a list of shard ids where the ID might reside.  The databases will
+                be queried in the order of this listing.
+                
+            query_chooser
+                for a given Query, returns the list of shard_ids where the query
+                should be issued.  Results from all shards returned will be 
+                combined together into a single listing.
+        
+        """
+        super(ShardedSession, self).__init__(**kwargs)
+        self.shard_chooser = shard_chooser
+        self.id_chooser = id_chooser
+        self.query_chooser = query_chooser
+        self.__binds = {}
+        self._mapper_flush_opts = {'connection_callable':self.connection}
+        self._query_cls = ShardedQuery
+        
+    def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
+        if shard_id is None:
+            shard_id = self.shard_chooser(mapper, instance)
+
+        if self.transaction is not None:
+            return self.transaction.connection(mapper, shard_id=shard_id)
+        else:
+            return self.get_bind(mapper, shard_id=shard_id, instance=instance).contextual_connect(**kwargs)
+    
+    def get_bind(self, mapper, shard_id=None, instance=None):
+        if shard_id is None:
+            shard_id = self.shard_chooser(mapper, instance)
+        return self.__binds[shard_id]
+
+    def bind_shard(self, shard_id, bind):
+        self.__binds[shard_id] = bind
+
+class ShardedQuery(Query):
+    def __init__(self, *args, **kwargs):
+        super(ShardedQuery, self).__init__(*args, **kwargs)
+        self.id_chooser = self.session.id_chooser
+        self.query_chooser = self.session.query_chooser
+        self._shard_id = None
+        
+    def _clone(self):
+        q = ShardedQuery.__new__(ShardedQuery)
+        q.__dict__ = self.__dict__.copy()
+        return q
+    
+    def set_shard(self, shard_id):
+        """return a new query, limited to a single shard ID.
+        
+        all subsequent operations with the returned query will 
+        be against the single shard regardless of other state.
+        """
+        
+        q = self._clone()
+        q._shard_id = shard_id
+        return q
+        
+    def _execute_and_instances(self, statement):
+        if self._shard_id is not None:
+            result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(statement, **self._params)
+            try:
+                return iter(self.instances(result))
+            finally:
+                result.close()
+        else:
+            partial = []
+            for shard_id in self.query_chooser(self):
+                result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(statement, **self._params)
+                try:
+                    partial = partial + list(self.instances(result))
+                finally:
+                    result.close()
+            # if some kind of in memory 'sorting' were done, this is where it would happen
+            return iter(partial)
+
+    def get(self, ident, **kwargs):
+        if self._shard_id is not None:
+            return super(ShardedQuery, self).get(ident)
+        else:
+            for shard_id in self.id_chooser(ident):
+                o = self.set_shard(shard_id).get(ident, **kwargs)
+                if o is not None:
+                    return o
+            else:
+                return None
+    
+    def load(self, ident, **kwargs):
+        if self._shard_id is not None:
+            return super(ShardedQuery, self).load(ident)
+        else:
+            for shard_id in self.id_chooser(ident):
+                o = self.set_shard(shard_id).load(ident, raiseerr=False, **kwargs)
+                if o is not None:
+                    return o
+            else:
+                raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
index 3f41d3f35b61aa0a4b8bb3c01d1371901016ebce..f59042810a79d3e8b1ff9bac66d314d9515c9504 100644 (file)
@@ -224,6 +224,7 @@ class UOWTransaction(object):
     def __init__(self, uow, session):
         self.uow = uow
         self.session = session
+        self.mapper_flush_opts = session._mapper_flush_opts
         
         # stores tuples of mapper/dependent mapper pairs,
         # representing a partial ordering fed into topological sort