- added eagerload_all(), allows eagerload_all('x.y.z') to specify eager
loading of all properties in the given path
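
  For example (a hedged sketch; "MyClass" and the attribute path are
  illustrative):

    session.query(MyClass).options(eagerload_all('x.y.z'))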
+
+ - a rudimentary sharding (horizontal scaling) system is introduced. This system
+ uses a modified Session which can distribute read and write operations among
+ multiple databases, based on user-defined functions defining the
+ "sharding strategy". Instances and their dependents can be distributed
+ and queried among multiple databases based on attribute values, round-robin
+ approaches or any other user-defined system. [ticket:618]
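+
+   A minimal construction sketch (hedged; the three chooser functions
+   are user-defined, and "some_engine" is a placeholder):
+
+     from sqlalchemy.orm.shard import ShardedSession
+     s = ShardedSession(shard_chooser, id_chooser, query_chooser)
+     s.bind_shard('some_shard_id', some_engine)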
- Eager loading has been enhanced to allow even more joins in more places.
It now functions at any arbitrary depth along self-referential
and cyclical structures; specify "join_depth" on relation() to
indicate how many levels of self-join are desired, with each level
getting a distinct table alias. The alias names
themselves are generated at compile time using a simple counting
scheme now and are a lot easier on the eyes, as well as of course
completely deterministic. [ticket:659]
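
  For example, a self-referential eager load (a hedged sketch; "Node"
  and "nodes" are illustrative names):

    mapper(Node, nodes, properties={
        'children':relation(Node, lazy=False, join_depth=2)
    })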
-
- added composite column properties. This allows you to create a
type which is represented by more than one column, when using the ORM.
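
  A hedged sketch, assuming the composite() property function and the
  __composite_values__() protocol ("Point" and "Vertex" are illustrative
  names):

    class Point(object):
        def __init__(self, x, y):
            self.x, self.y = x, y
        def __composite_values__(self):
            return self.x, self.y

    mapper(Vertex, vertices, properties={
        'start':composite(Point, vertices.c.x1, vertices.c.y1)
    })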
--- /dev/null
+"""a basic example of using the SQLAlchemy Sharding API.
+Sharding refers to horizontally scaling data across multiple
+databases.
+
+In this example, four sqlite databases will store information about
+weather data on a database-per-continent basis.
+
+To set up a sharding system, you need:
+ 1. multiple databases, each assigned a 'shard id'
+ 2. a function which can return a single shard id, given an instance
+ to be saved; this is called "shard_chooser"
+ 3. a function which can return a list of shard ids which apply to a particular
+ instance identifier; this is called "id_chooser". If it returns all shard ids,
+ all shards will be searched.
+ 4. a function which can return a list of shard ids to try, given a particular
+ Query ("query_chooser"). If it returns all shard ids, all shards will be
+ queried and the results joined together.
+"""
+
+# step 1. imports
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from sqlalchemy.orm.shard import ShardedSession
+from sqlalchemy.sql import ColumnOperators
+import datetime, operator
+
+# step 2. databases
+echo = True
+db1 = create_engine('sqlite:///shard1.db', echo=echo)
+db2 = create_engine('sqlite:///shard2.db', echo=echo)
+db3 = create_engine('sqlite:///shard3.db', echo=echo)
+db4 = create_engine('sqlite:///shard4.db', echo=echo)
+
+
+# step 3. create session function. this binds the shard ids
+# to databases within a ShardedSession and returns it.
+def create_session():
+ s = ShardedSession(shard_chooser, id_chooser, query_chooser)
+ s.bind_shard('north_america', db1)
+ s.bind_shard('asia', db2)
+ s.bind_shard('europe', db3)
+ s.bind_shard('south_america', db4)
+ return s
+
+# step 4. table setup.
+meta = MetaData()
+
+# we need a way to create identifiers which are unique across all
+# databases. one easy way would be to just use a composite primary key, where one
+# value is the shard id. but here, we'll show something more "generic", an
+# id generation function. we'll use a simplistic "id table" stored in database
+# #1. Any other method will do just as well; UUID, hilo, application-specific, etc.
+
+ids = Table('ids', meta,
+ Column('nextid', Integer, nullable=False))
+
+def id_generator(ctx):
+ # in reality, might want to use a separate transaction for this.
+ c = db1.connect()
+ nextid = c.execute(ids.select(for_update=True)).scalar()
+ c.execute(ids.update(values={ids.c.nextid : ids.c.nextid + 1}))
+ return nextid
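+
+# as the comment above notes, any globally-unique scheme works just as
+# well. a hedged sketch of a UUID-based alternative (illustrative only
+# and not used below, since it would require string rather than integer
+# 'id' columns):
+#
+#     import uuid
+#     def uuid_id_generator(ctx):
+#         return uuid.uuid4().hex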
+
+# table setup. we'll store a lead table of continents/cities,
+# and a secondary table storing locations.
+# a particular row will be placed in the database whose shard id corresponds to the
+# 'continent'. in this setup, secondary rows in 'weather_reports' will
+# be placed in the same DB as that of the parent, but this can be changed
+# if you're willing to write more complex sharding functions.
+
+weather_locations = Table("weather_locations", meta,
+ Column('id', Integer, primary_key=True, default=id_generator),
+ Column('continent', String(30), nullable=False),
+ Column('city', String(50), nullable=False)
+ )
+
+weather_reports = Table("weather_reports", meta,
+ Column('id', Integer, primary_key=True),
+ Column('location_id', Integer, ForeignKey('weather_locations.id')),
+ Column('temperature', Numeric),
+ Column('report_time', DateTime, default=datetime.datetime.now),
+)
+
+# create tables
+for db in (db1, db2, db3, db4):
+ meta.drop_all(db)
+ meta.create_all(db)
+
+# establish initial "id" in db1
+db1.execute(ids.insert(), nextid=1)
+
+
+# step 5. define sharding functions.
+
+# we'll use a straight mapping of a particular set of "continent"
+# attributes to shard id.
+shard_lookup = {
+ 'North America':'north_america',
+ 'Asia':'asia',
+ 'Europe':'europe',
+ 'South America':'south_america'
+}
+
+# shard_chooser - looks at the given instance and returns a shard id.
+# note that we need to define conditions for
+# the WeatherLocation class, as well as our secondary Report class which will
+# point back to its WeatherLocation via its 'location' attribute.
+def shard_chooser(mapper, instance):
+ if isinstance(instance, WeatherLocation):
+ return shard_lookup[instance.continent]
+ else:
+ return shard_chooser(mapper, instance.location)
+
+# id_chooser. given a primary key, returns a list of shards
+# to search. here, we don't have any particular information from a
+# pk so we just return all shard ids. often, you'd want to do some
+# kind of round-robin strategy here so that requests are evenly
+# distributed among DBs (a sketch of that idea follows the function below).
+def id_chooser(ident):
+ return ['north_america', 'asia', 'europe', 'south_america']
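+
+# a hedged sketch of that round-robin idea: shuffling the shard ids
+# spreads repeated lookups evenly across databases. illustrative only;
+# it is not wired into the session below.
+import random
+def shuffled_id_chooser(ident):
+    shard_ids = ['north_america', 'asia', 'europe', 'south_america']
+    random.shuffle(shard_ids)
+    return shard_ids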
+
+# query_chooser. this also returns a list of shard ids, which can
+# just be all of them. but here we'll search into the Query in order
+# to try to narrow down the list of shards to query.
+def query_chooser(query):
+ ids = []
+
+ # here we will traverse through the query's criterion, searching
+ # for SQL constructs. we'll grab continent names as we find them
+ # and convert to shard ids
+ class FindContinent(sql.ClauseVisitor):
+ def visit_binary(self, binary):
+ if binary.left is weather_locations.c.continent:
+ if binary.operator == operator.eq:
+ ids.append(shard_lookup[binary.right.value])
+ elif binary.operator == ColumnOperators.in_op:
+ for bind in binary.right.clauses:
+ ids.append(shard_lookup[bind.value])
+
+ FindContinent().traverse(query._criterion)
+ if len(ids) == 0:
+ return ['north_america', 'asia', 'europe', 'south_america']
+ else:
+ return ids
+
+# step 6. mapped classes.
+class WeatherLocation(object):
+ def __init__(self, continent, city):
+ self.continent = continent
+ self.city = city
+
+class Report(object):
+ def __init__(self, temperature):
+ self.temperature = temperature
+
+# step 7. mappers
+mapper(WeatherLocation, weather_locations, properties={
+ 'reports':relation(Report, backref='location')
+})
+
+mapper(Report, weather_reports)
+
+
+# save and load objects!
+
+tokyo = WeatherLocation('Asia', 'Tokyo')
+newyork = WeatherLocation('North America', 'New York')
+toronto = WeatherLocation('North America', 'Toronto')
+london = WeatherLocation('Europe', 'London')
+dublin = WeatherLocation('Europe', 'Dublin')
+brasilia = WeatherLocation('South America', 'Brasilia')
+quito = WeatherLocation('South America', 'Quito')
+
+tokyo.reports.append(Report(80.0))
+newyork.reports.append(Report(75))
+quito.reports.append(Report(85))
+
+sess = create_session()
+for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
+ sess.save(c)
+sess.flush()
+
+sess.clear()
+
+t = sess.query(WeatherLocation).get(tokyo.id)
+assert t.city == tokyo.city
+assert t.reports[0].temperature == 80.0
+
+north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America')
+assert [c.city for c in north_american_cities] == ['New York', 'Toronto']
+
+asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_('Europe', 'Asia'))
+assert set([c.city for c in asia_and_europe]) == set(['Tokyo', 'London', 'Dublin'])
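+
+# a query can also be pinned to a single shard via set_shard(), bypassing
+# query_chooser entirely; this sketch assumes we already know Tokyo
+# lives in the 'asia' shard:
+asia_cities = sess.query(WeatherLocation).set_shard('asia').filter(WeatherLocation.continent == 'Asia')
+assert [c.city for c in asia_cities] == ['Tokyo']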
+
self.save_obj([obj], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
return
- connection = uowtransaction.transaction.connection(self)
-
+ if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
+ tups = [(obj, connection_callable(self, obj)) for obj in objects]
+ else:
+ connection = uowtransaction.transaction.connection(self)
+ tups = [(obj, connection) for obj in objects]
+
if not postupdate:
- for obj in objects:
+ for obj, connection in tups:
if not has_identity(obj):
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.before_insert(mapper, connection, obj)
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.before_update(mapper, connection, obj)
- for obj in objects:
+ for obj, connection in tups:
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
# and another instance with the same identity key already exists as persistent. convert to an
# UPDATE if so.
insert = []
update = []
- for obj in objects:
+ for obj, connection in tups:
mapper = object_mapper(obj)
if table not in mapper.tables or not mapper._has_pks(table):
continue
if hasdata:
# if none of the attributes changed, dont even
# add the row to be updated.
- update.append((obj, params, mapper))
+ update.append((obj, params, mapper, connection))
else:
- insert.append((obj, params, mapper))
+ insert.append((obj, params, mapper, connection))
if len(update):
mapper = table_to_mapper[table]
return 0
update.sort(comparator)
for rec in update:
- (obj, params, mapper) = rec
+ (obj, params, mapper, connection) = rec
c = connection.execute(statement, params)
mapper._postfetch(connection, table, obj, c, c.last_updated_params())
- updated_objects.add(obj)
+ updated_objects.add((obj, connection))
rows += c.rowcount
if c.supports_sane_rowcount() and rows != len(update):
return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order)
insert.sort(comparator)
for rec in insert:
- (obj, params, mapper) = rec
+ (obj, params, mapper, connection) = rec
c = connection.execute(statement, params)
primary_key = c.last_inserted_ids()
if primary_key is not None:
mapper._synchronizer.execute(obj, obj)
sync(mapper)
- inserted_objects.add(obj)
+ inserted_objects.add((obj, connection))
if not postupdate:
- for obj in inserted_objects:
+ for obj, connection in inserted_objects:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.after_insert(mapper, connection, obj)
- for obj in updated_objects:
+ for obj, connection in updated_objects:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.after_update(mapper, connection, obj)
if self.__should_log_debug:
self.__log_debug("delete_obj() start")
- connection = uowtransaction.transaction.connection(self)
+ if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
+ tups = [(obj, connection_callable(self, obj)) for obj in objects]
+ else:
+ connection = uowtransaction.transaction.connection(self)
+ tups = [(obj, connection) for obj in objects]
- for obj in objects:
+ for (obj, connection) in tups:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.before_delete(mapper, connection, obj)
table_to_mapper.setdefault(t, mapper)
for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=True):
- delete = []
- for obj in objects:
+ delete = {}
+ for (obj, connection) in tups:
mapper = object_mapper(obj)
if table not in mapper.tables or not mapper._has_pks(table):
continue
if not hasattr(obj, '_instance_key'):
continue
else:
- delete.append(params)
+ delete.setdefault(connection, []).append(params)
for col in mapper.pks_by_table[table]:
params[col.key] = mapper.get_attr_by_column(obj, col)
if mapper.version_id_col is not None:
params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col)
- deleted_objects.add(obj)
- if len(delete):
+ deleted_objects.add((obj, connection))
+ for connection, del_objects in delete.iteritems():
mapper = table_to_mapper[table]
def comparator(a, b):
for col in mapper.pks_by_table[table]:
if x != 0:
return x
return 0
- delete.sort(comparator)
+ del_objects.sort(comparator)
clause = sql.and_()
for col in mapper.pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True))
if mapper.version_id_col is not None:
clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True))
statement = table.delete(clause)
- c = connection.execute(statement, delete)
- if c.supports_sane_rowcount() and c.rowcount != len(delete):
+ c = connection.execute(statement, del_objects)
+ if c.supports_sane_rowcount() and c.rowcount != len(del_objects):
-            raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.rowcount, len(delete)))
+            raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
- for obj in deleted_objects:
+ for obj, connection in deleted_objects:
for mapper in object_mapper(obj).iterate_to_root():
mapper.extension.after_delete(mapper, connection, obj)
key = self.mapper.identity_key_from_primary_key(ident)
return self._get(key, ident, **kwargs)
- def load(self, ident, **kwargs):
+ def load(self, ident, raiseerr=True, **kwargs):
"""Return an instance of the object based on the given
identifier.
return ret
key = self.mapper.identity_key_from_primary_key(ident)
instance = self._get(key, ident, reload=True, **kwargs)
- if instance is None:
+ if instance is None and raiseerr:
raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
return instance
statement.use_labels = True
if self.session.autoflush:
self.session.flush()
+ return self._execute_and_instances(statement)
+
+ def _execute_and_instances(self, statement):
result = self.session.execute(statement, params=self._params, mapper=self.mapper)
try:
return iter(self.instances(result))
finally:
result.close()
-
def instances(self, cursor, *mappers_or_columns, **kwargs):
"""Return a list of mapped instances corresponding to the rows
in a given *cursor* (i.e. ``ResultProxy``).
self.autoflush = autoflush
self.nested = nested
- def connection(self, mapper_or_class, entity_name=None):
+ def connection(self, mapper_or_class, entity_name=None, **kwargs):
if isinstance(mapper_or_class, type):
mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
- engine = self.session.get_bind(mapper_or_class)
+ engine = self.session.get_bind(mapper_or_class, **kwargs)
return self.get_or_add(engine)
def _begin(self, **kwargs):
self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
self.bind = bind
- self.binds = {}
+ self.__binds = {}
self.echo_uow = echo_uow
self.weak_identity_map = weak_identity_map
self.transaction = None
self.autoflush = autoflush
self.transactional = transactional or autoflush
self.twophase = twophase
+ self._query_cls = query.Query
+ self._mapper_flush_opts = {}
if self.transactional:
self.begin()
_sessions[self.hash_key] = self
self.transaction = self.transaction.commit()
if self.transaction is None and self.transactional:
self.begin()
-
+
def connection(self, mapper=None, **kwargs):
"""Return a ``Connection`` corresponding to this session's
transactional context, if any.
if isinstance(mapper, type):
mapper = _class_mapper(mapper, entity_name=entity_name)
- self.binds[mapper] = bind
+ self.__binds[mapper] = bind
def bind_table(self, table, bind):
"""Bind the given `table` to the given ``Engine`` or ``Connection``.
given `bind`.
"""
- self.binds[table] = bind
+ self.__binds[table] = bind
def get_bind(self, mapper):
"""Return the ``Engine`` or ``Connection`` which is used to execute
return self.bind
else:
raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()")
- elif self.binds.has_key(mapper):
- return self.binds[mapper]
- elif self.binds.has_key(mapper.mapped_table):
- return self.binds[mapper.mapped_table]
+ elif self.__binds.has_key(mapper):
+ return self.__binds[mapper]
+ elif self.__binds.has_key(mapper.mapped_table):
+ return self.__binds[mapper.mapped_table]
elif self.bind is not None:
return self.bind
else:
entity_name = kwargs.pop('entity_name', None)
if isinstance(mapper_or_class, type):
- q = query.Query(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
+ q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
else:
- q = query.Query(mapper_or_class, self, **kwargs)
+ q = self._query_cls(mapper_or_class, self, **kwargs)
for ent in addtl_entities:
q = q.add_entity(ent)
--- /dev/null
+from sqlalchemy import exceptions
+from sqlalchemy.orm.session import Session
+from sqlalchemy.orm import Query
+
+class ShardedSession(Session):
+ def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs):
+ """construct a ShardedSession.
+
+ shard_chooser
+ a callable which, passed a Mapper and a mapped instance, returns a
+ shard ID. this id may be based off of the attributes present within the
+ object, or on some round-robin scheme. If the scheme is based on a
+ selection, it should set whatever state on the instance is needed to
+ mark it in the future as participating in that shard.
+
+ id_chooser
+ a callable, passed a tuple of identity values, which should return
+ a list of shard ids where the ID might reside. The databases will
+ be queried in the order of this listing.
+
+ query_chooser
+ for a given Query, returns the list of shard_ids where the query
+ should be issued. Results from all shards returned will be
+ combined together into a single listing.
+
+ """
+ super(ShardedSession, self).__init__(**kwargs)
+ self.shard_chooser = shard_chooser
+ self.id_chooser = id_chooser
+ self.query_chooser = query_chooser
+ self.__binds = {}
+ self._mapper_flush_opts = {'connection_callable':self.connection}
+ self._query_cls = ShardedQuery
+
+ def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
+ if shard_id is None:
+ shard_id = self.shard_chooser(mapper, instance)
+
+ if self.transaction is not None:
+ return self.transaction.connection(mapper, shard_id=shard_id)
+ else:
+ return self.get_bind(mapper, shard_id=shard_id, instance=instance).contextual_connect(**kwargs)
+
+ def get_bind(self, mapper, shard_id=None, instance=None):
+ if shard_id is None:
+ shard_id = self.shard_chooser(mapper, instance)
+ return self.__binds[shard_id]
+
+ def bind_shard(self, shard_id, bind):
+ self.__binds[shard_id] = bind
+
+class ShardedQuery(Query):
+ def __init__(self, *args, **kwargs):
+ super(ShardedQuery, self).__init__(*args, **kwargs)
+ self.id_chooser = self.session.id_chooser
+ self.query_chooser = self.session.query_chooser
+ self._shard_id = None
+
+ def _clone(self):
+ q = ShardedQuery.__new__(ShardedQuery)
+ q.__dict__ = self.__dict__.copy()
+ return q
+
+ def set_shard(self, shard_id):
+ """return a new query, limited to a single shard ID.
+
+ all subsequent operations with the returned query will
+ be against the single shard regardless of other state.
+ """
+
+ q = self._clone()
+ q._shard_id = shard_id
+ return q
+
+ def _execute_and_instances(self, statement):
+ if self._shard_id is not None:
+ result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(statement, **self._params)
+ try:
+ return iter(self.instances(result))
+ finally:
+ result.close()
+ else:
+ partial = []
+ for shard_id in self.query_chooser(self):
+ result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(statement, **self._params)
+ try:
+ partial = partial + list(self.instances(result))
+ finally:
+ result.close()
+ # if some kind of in memory 'sorting' were done, this is where it would happen
+ return iter(partial)
+
+ def get(self, ident, **kwargs):
+ if self._shard_id is not None:
+ return super(ShardedQuery, self).get(ident)
+ else:
+ for shard_id in self.id_chooser(ident):
+ o = self.set_shard(shard_id).get(ident, **kwargs)
+ if o is not None:
+ return o
+ else:
+ return None
+
+ def load(self, ident, **kwargs):
+ if self._shard_id is not None:
+ return super(ShardedQuery, self).load(ident)
+ else:
+ for shard_id in self.id_chooser(ident):
+ o = self.set_shard(shard_id).load(ident, raiseerr=False, **kwargs)
+ if o is not None:
+ return o
+ else:
+ raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
def __init__(self, uow, session):
self.uow = uow
self.session = session
+ self.mapper_flush_opts = session._mapper_flush_opts
# stores tuples of mapper/dependent mapper pairs,
# representing a partial ordering fed into topological sort