From: Mike Bayer Date: Tue, 24 Jul 2007 04:05:55 +0000 (+0000) Subject: - a rudimental sharding (horizontal scaling) system is introduced. This system X-Git-Tag: rel_0_4_6~35 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=91803812e8a92fa0b7d6137388e107abd96046d1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - a rudimental sharding (horizontal scaling) system is introduced. This system uses a modified Session which can distribute read and write operations among multiple databases, based on user-defined functions defining the "sharding strategy". Instances and their dependents can be distributed and queried among multiple databases based on attribute values, round-robin approaches or any other user-defined system. [ticket:618] --- diff --git a/CHANGES b/CHANGES index 092ee176ec..668c94626c 100644 --- a/CHANGES +++ b/CHANGES @@ -67,6 +67,13 @@ - added eagerload_all(), allows eagerload_all('x.y.z') to specify eager loading of all properties in the given path + + - a rudimental sharding (horizontal scaling) system is introduced. This system + uses a modified Session which can distribute read and write operations among + multiple databases, based on user-defined functions defining the + "sharding strategy". Instances and their dependents can be distributed + and queried among multiple databases based on attribute values, round-robin + approaches or any other user-defined system. [ticket:618] - Eager loading has been enhanced to allow even more joins in more places. It now functions at any arbitrary depth along self-referential @@ -76,7 +83,6 @@ themselves are generated at compile time using a simple counting scheme now and are a lot easier on the eyes, as well as of course completely deterministic. [ticket:659] - - added composite column properties. 
This allows you to create a type which is represented by more than one column, when using the diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py new file mode 100644 index 0000000000..1f80c5445e --- /dev/null +++ b/examples/sharding/attribute_shard.py @@ -0,0 +1,194 @@ +"""a basic example of using the SQLAlchemy Sharding API. +Sharding refers to horizontally scaling data across multiple +databases. + +In this example, four sqlite databases will store information about +weather data on a database-per-continent basis. + +To set up a sharding system, you need: + 1. multiple databases, each assigned a 'shard id' + 2. a function which can return a single shard id, given an instance + to be saved; this is called "shard_chooser" + 3. a function which can return a list of shard ids which apply to a particular + instance identifier; this is called "id_chooser". If it returns all shard ids, + all shards will be searched. + 4. a function which can return a list of shard ids to try, given a particular + Query ("query_chooser"). If it returns all shard ids, all shards will be + queried and the results joined together. +""" + +# step 1. imports +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.orm.shard import ShardedSession +from sqlalchemy.sql import ColumnOperators +import datetime, operator + +# step 2. databases +echo = True +db1 = create_engine('sqlite:///shard1.db', echo=echo) +db2 = create_engine('sqlite:///shard2.db', echo=echo) +db3 = create_engine('sqlite:///shard3.db', echo=echo) +db4 = create_engine('sqlite:///shard4.db', echo=echo) + + +# step 3. create session function. this binds the shard ids +# to databases within a ShardedSession and returns it. +def create_session(): + s = ShardedSession(shard_chooser, id_chooser, query_chooser) + s.bind_shard('north_america', db1) + s.bind_shard('asia', db2) + s.bind_shard('europe', db3) + s.bind_shard('south_america', db4) + return s + +# step 4. table setup. 
+meta = MetaData() + +# we need a way to create identifiers which are unique across all +# databases. one easy way would be to just use a composite primary key, where one +# value is the shard id. but here, we'll show something more "generic", an +# id generation function. we'll use a simplistic "id table" stored in database +# #1. Any other method will do just as well; UUID, hilo, application-specific, etc. + +ids = Table('ids', meta, + Column('nextid', Integer, nullable=False)) + +def id_generator(ctx): + # in reality, might want to use a separate transaction for this. + c = db1.connect() + nextid = c.execute(ids.select(for_update=True)).scalar() + c.execute(ids.update(values={ids.c.nextid : ids.c.nextid + 1})) + return nextid + +# table setup. we'll store a lead table of continents/cities, +# and a secondary table storing weather reports. +# a particular row will be placed in the database whose shard id corresponds to the +# 'continent'. in this setup, secondary rows in 'weather_reports' will +# be placed in the same DB as that of the parent, but this can be changed +# if you're willing to write more complex sharding functions. + +weather_locations = Table("weather_locations", meta, + Column('id', Integer, primary_key=True, default=id_generator), + Column('continent', String(30), nullable=False), + Column('city', String(50), nullable=False) + ) + +weather_reports = Table("weather_reports", meta, + Column('id', Integer, primary_key=True), + Column('location_id', Integer, ForeignKey('weather_locations.id')), + Column('temperature', Numeric), + Column('report_time', DateTime, default=datetime.datetime.now), +) + +# create tables +for db in (db1, db2, db3, db4): + meta.drop_all(db) + meta.create_all(db) + +# establish initial "id" in db1 +db1.execute(ids.insert(), nextid=1) + + +# step 5. define sharding functions. + +# we'll use a straight mapping of a particular set of "continent" +# attributes to shard id. 
+shard_lookup = { + 'North America':'north_america', + 'Asia':'asia', + 'Europe':'europe', + 'South America':'south_america' +} + +# shard_chooser - looks at the given instance and returns a shard id +# note that we need to define conditions for +# the WeatherLocation class, as well as our secondary Report class which will +# point back to its WeatherLocation via its 'location' attribute. +def shard_chooser(mapper, instance): + if isinstance(instance, WeatherLocation): + return shard_lookup[instance.continent] + else: + return shard_chooser(mapper, instance.location) + +# id_chooser. given a primary key, returns a list of shards +# to search. here, we don't have any particular information from a +# pk so we just return all shard ids. often, you'd want to do some +# kind of round-robin strategy here so that requests are evenly +# distributed among DBs +def id_chooser(ident): + return ['north_america', 'asia', 'europe', 'south_america'] + +# query_chooser. this also returns a list of shard ids, which can +# just be all of them. but here we'll search into the Query in order +# to try to narrow down the list of shards to query. +def query_chooser(query): + ids = [] + + # here we will traverse through the query's criterion, searching + # for SQL constructs. we'll grab continent names as we find them + # and convert to shard ids + class FindContinent(sql.ClauseVisitor): + def visit_binary(self, binary): + if binary.left is weather_locations.c.continent: + if binary.operator == operator.eq: + ids.append(shard_lookup[binary.right.value]) + elif binary.operator == ColumnOperators.in_op: + for bind in binary.right.clauses: + ids.append(shard_lookup[bind.value]) + + FindContinent().traverse(query._criterion) + if len(ids) == 0: + return ['north_america', 'asia', 'europe', 'south_america'] + else: + return ids + +# step 6. mapped classes. 
+class WeatherLocation(object): + def __init__(self, continent, city): + self.continent = continent + self.city = city + +class Report(object): + def __init__(self, temperature): + self.temperature = temperature + +# step 7. mappers +mapper(WeatherLocation, weather_locations, properties={ + 'reports':relation(Report, backref='location') +}) + +mapper(Report, weather_reports) + + +# save and load objects! + +tokyo = WeatherLocation('Asia', 'Tokyo') +newyork = WeatherLocation('North America', 'New York') +toronto = WeatherLocation('North America', 'Toronto') +london = WeatherLocation('Europe', 'London') +dublin = WeatherLocation('Europe', 'Dublin') +brasilia = WeatherLocation('South America', 'Brasila') +quito = WeatherLocation('South America', 'Quito') + +tokyo.reports.append(Report(80.0)) +newyork.reports.append(Report(75)) +quito.reports.append(Report(85)) + +sess = create_session() +for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]: + sess.save(c) +sess.flush() + +sess.clear() + +t = sess.query(WeatherLocation).get(tokyo.id) +assert t.city == tokyo.city +assert t.reports[0].temperature == 80.0 + +north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America') +assert [c.city for c in north_american_cities] == ['New York', 'Toronto'] + +asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_('Europe', 'Asia')) +assert set([c.city for c in asia_and_europe]) == set(['Tokyo', 'London', 'Dublin']) + diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 62b4d7d896..e035537351 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1095,10 +1095,15 @@ class Mapper(object): self.save_obj([obj], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) return - connection = uowtransaction.transaction.connection(self) - + if 'connection_callable' in uowtransaction.mapper_flush_opts: + connection_callable = 
uowtransaction.mapper_flush_opts['connection_callable'] + tups = [(obj, connection_callable(self, obj)) for obj in objects] + else: + connection = uowtransaction.transaction.connection(self) + tups = [(obj, connection) for obj in objects] + if not postupdate: - for obj in objects: + for obj, connection in tups: if not has_identity(obj): for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.before_insert(mapper, connection, obj) @@ -1106,7 +1111,7 @@ class Mapper(object): for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.before_update(mapper, connection, obj) - for obj in objects: + for obj, connection in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), # and another instance with the same identity key already exists as persistent. convert to an # UPDATE if so. @@ -1137,7 +1142,7 @@ class Mapper(object): insert = [] update = [] - for obj in objects: + for obj, connection in tups: mapper = object_mapper(obj) if table not in mapper.tables or not mapper._has_pks(table): continue @@ -1215,9 +1220,9 @@ class Mapper(object): if hasdata: # if none of the attributes changed, dont even # add the row to be updated. 
- update.append((obj, params, mapper)) + update.append((obj, params, mapper, connection)) else: - insert.append((obj, params, mapper)) + insert.append((obj, params, mapper, connection)) if len(update): mapper = table_to_mapper[table] @@ -1237,11 +1242,11 @@ class Mapper(object): return 0 update.sort(comparator) for rec in update: - (obj, params, mapper) = rec + (obj, params, mapper, connection) = rec c = connection.execute(statement, params) mapper._postfetch(connection, table, obj, c, c.last_updated_params()) - updated_objects.add(obj) + updated_objects.add((obj, connection)) rows += c.rowcount if c.supports_sane_rowcount() and rows != len(update): @@ -1253,7 +1258,7 @@ class Mapper(object): return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order) insert.sort(comparator) for rec in insert: - (obj, params, mapper) = rec + (obj, params, mapper, connection) = rec c = connection.execute(statement, params) primary_key = c.last_inserted_ids() if primary_key is not None: @@ -1275,12 +1280,12 @@ class Mapper(object): mapper._synchronizer.execute(obj, obj) sync(mapper) - inserted_objects.add(obj) + inserted_objects.add((obj, connection)) if not postupdate: - for obj in inserted_objects: + for obj, connection in inserted_objects: for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.after_insert(mapper, connection, obj) - for obj in updated_objects: + for obj, connection in updated_objects: for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.after_update(mapper, connection, obj) @@ -1320,9 +1325,14 @@ class Mapper(object): if self.__should_log_debug: self.__log_debug("delete_obj() start") - connection = uowtransaction.transaction.connection(self) + if 'connection_callable' in uowtransaction.mapper_flush_opts: + connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] + tups = [(obj, connection_callable(self, obj)) for obj in objects] + else: + connection = uowtransaction.transaction.connection(self) + tups = [(obj, 
connection) for obj in objects] - for obj in objects: + for (obj, connection) in tups: for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.before_delete(mapper, connection, obj) @@ -1333,8 +1343,8 @@ class Mapper(object): table_to_mapper.setdefault(t, mapper) for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=True): - delete = [] - for obj in objects: + delete = {} + for (obj, connection) in tups: mapper = object_mapper(obj) if table not in mapper.tables or not mapper._has_pks(table): continue @@ -1343,13 +1353,13 @@ class Mapper(object): if not hasattr(obj, '_instance_key'): continue else: - delete.append(params) + delete.setdefault(connection, []).append(params) for col in mapper.pks_by_table[table]: params[col.key] = mapper.get_attr_by_column(obj, col) if mapper.version_id_col is not None: params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col) - deleted_objects.add(obj) - if len(delete): + deleted_objects.add((obj, connection)) + for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] def comparator(a, b): for col in mapper.pks_by_table[table]: @@ -1357,18 +1367,18 @@ class Mapper(object): if x != 0: return x return 0 - delete.sort(comparator) + del_objects.sort(comparator) clause = sql.and_() for col in mapper.pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True)) if mapper.version_id_col is not None: clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True)) statement = table.delete(clause) - c = connection.execute(statement, delete) - if c.supports_sane_rowcount() and c.rowcount != len(delete): + c = connection.execute(statement, del_objects) + if c.supports_sane_rowcount() and c.rowcount != len(del_objects): raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % 
(c.rowcount, len(delete))) - for obj in deleted_objects: + for obj, connection in deleted_objects: for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.after_delete(mapper, connection, obj) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 86e6d522f3..68a71e5655 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -81,7 +81,7 @@ class Query(object): key = self.mapper.identity_key_from_primary_key(ident) return self._get(key, ident, **kwargs) - def load(self, ident, **kwargs): + def load(self, ident, raiseerr=True, **kwargs): """Return an instance of the object based on the given identifier. @@ -97,7 +97,7 @@ class Query(object): return ret key = self.mapper.identity_key_from_primary_key(ident) instance = self._get(key, ident, reload=True, **kwargs) - if instance is None: + if instance is None and raiseerr: raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident)) return instance @@ -608,13 +608,15 @@ class Query(object): statement.use_labels = True if self.session.autoflush: self.session.flush() + return self._execute_and_instances(statement) + + def _execute_and_instances(self, statement): result = self.session.execute(statement, params=self._params, mapper=self.mapper) try: return iter(self.instances(result)) finally: result.close() - def instances(self, cursor, *mappers_or_columns, **kwargs): """Return a list of mapped instances corresponding to the rows in a given *cursor* (i.e. ``ResultProxy``). 
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index d41eac2b58..6b5c4a0725 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -28,10 +28,10 @@ class SessionTransaction(object): self.autoflush = autoflush self.nested = nested - def connection(self, mapper_or_class, entity_name=None): + def connection(self, mapper_or_class, entity_name=None, **kwargs): if isinstance(mapper_or_class, type): mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name) - engine = self.session.get_bind(mapper_or_class) + engine = self.session.get_bind(mapper_or_class, **kwargs) return self.get_or_add(engine) def _begin(self, **kwargs): @@ -137,7 +137,7 @@ class Session(object): self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map) self.bind = bind - self.binds = {} + self.__binds = {} self.echo_uow = echo_uow self.weak_identity_map = weak_identity_map self.transaction = None @@ -145,6 +145,8 @@ class Session(object): self.autoflush = autoflush self.transactional = transactional or autoflush self.twophase = twophase + self._query_cls = query.Query + self._mapper_flush_opts = {} if self.transactional: self.begin() _sessions[self.hash_key] = self @@ -185,7 +187,7 @@ class Session(object): self.transaction = self.transaction.commit() if self.transaction is None and self.transactional: self.begin() - + def connection(self, mapper=None, **kwargs): """Return a ``Connection`` corresponding to this session's transactional context, if any. @@ -263,7 +265,7 @@ class Session(object): if isinstance(mapper, type): mapper = _class_mapper(mapper, entity_name=entity_name) - self.binds[mapper] = bind + self.__binds[mapper] = bind def bind_table(self, table, bind): """Bind the given `table` to the given ``Engine`` or ``Connection``. @@ -272,7 +274,7 @@ class Session(object): given `bind`. 
""" - self.binds[table] = bind + self.__binds[table] = bind def get_bind(self, mapper): """Return the ``Engine`` or ``Connection`` which is used to execute @@ -306,10 +308,10 @@ class Session(object): return self.bind else: raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()") - elif self.binds.has_key(mapper): - return self.binds[mapper] - elif self.binds.has_key(mapper.mapped_table): - return self.binds[mapper.mapped_table] + elif self.__binds.has_key(mapper): + return self.__binds[mapper] + elif self.__binds.has_key(mapper.mapped_table): + return self.__binds[mapper.mapped_table] elif self.bind is not None: return self.bind else: @@ -326,9 +328,9 @@ class Session(object): entity_name = kwargs.pop('entity_name', None) if isinstance(mapper_or_class, type): - q = query.Query(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs) + q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs) else: - q = query.Query(mapper_or_class, self, **kwargs) + q = self._query_cls(mapper_or_class, self, **kwargs) for ent in addtl_entities: q = q.add_entity(ent) diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py new file mode 100644 index 0000000000..cc13f8c1fe --- /dev/null +++ b/lib/sqlalchemy/orm/shard.py @@ -0,0 +1,112 @@ +from sqlalchemy.orm.session import Session +from sqlalchemy.orm import Query + +class ShardedSession(Session): + def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs): + """construct a ShardedSession. + + shard_chooser + a callable which, passed a Mapper and a mapped instance, returns a + shard ID. this id may be based off of the attributes present within the + object, or on some round-robin scheme. If the scheme is based on a + selection, it should set whatever state on the instance to mark it in + the future as participating in that shard. 
+ + id_chooser + a callable, passed a tuple of identity values, which should return + a list of shard ids where the ID might reside. The databases will + be queried in the order of this listing. + + query_chooser + for a given Query, returns the list of shard_ids where the query + should be issued. Results from all shards returned will be + combined together into a single listing. + + """ + super(ShardedSession, self).__init__(**kwargs) + self.shard_chooser = shard_chooser + self.id_chooser = id_chooser + self.query_chooser = query_chooser + self.__binds = {} + self._mapper_flush_opts = {'connection_callable':self.connection} + self._query_cls = ShardedQuery + + def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): + if shard_id is None: + shard_id = self.shard_chooser(mapper, instance) + + if self.transaction is not None: + return self.transaction.connection(mapper, shard_id=shard_id) + else: + return self.get_bind(mapper, shard_id=shard_id, instance=instance).contextual_connect(**kwargs) + + def get_bind(self, mapper, shard_id=None, instance=None): + if shard_id is None: + shard_id = self.shard_chooser(mapper, instance) + return self.__binds[shard_id] + + def bind_shard(self, shard_id, bind): + self.__binds[shard_id] = bind + +class ShardedQuery(Query): + def __init__(self, *args, **kwargs): + super(ShardedQuery, self).__init__(*args, **kwargs) + self.id_chooser = self.session.id_chooser + self.query_chooser = self.session.query_chooser + self._shard_id = None + + def _clone(self): + q = ShardedQuery.__new__(ShardedQuery) + q.__dict__ = self.__dict__.copy() + return q + + def set_shard(self, shard_id): + """return a new query, limited to a single shard ID. + + all subsequent operations with the returned query will + be against the single shard regardless of other state. 
+ """ + + q = self._clone() + q._shard_id = shard_id + return q + + def _execute_and_instances(self, statement): + if self._shard_id is not None: + result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(statement, **self._params) + try: + return iter(self.instances(result)) + finally: + result.close() + else: + partial = [] + for shard_id in self.query_chooser(self): + result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(statement, **self._params) + try: + partial = partial + list(self.instances(result)) + finally: + result.close() + # if some kind of in memory 'sorting' were done, this is where it would happen + return iter(partial) + + def get(self, ident, **kwargs): + if self._shard_id is not None: + return super(ShardedQuery, self).get(ident) + else: + for shard_id in self.id_chooser(ident): + o = self.set_shard(shard_id).get(ident, **kwargs) + if o is not None: + return o + else: + return None + + def load(self, ident, **kwargs): + if self._shard_id is not None: + return super(ShardedQuery, self).load(ident) + else: + for shard_id in self.id_chooser(ident): + o = self.set_shard(shard_id).load(ident, raiseerr=False, **kwargs) + if o is not None: + return o + else: + raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident)) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 3f41d3f35b..f59042810a 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -224,6 +224,7 @@ class UOWTransaction(object): def __init__(self, uow, session): self.uow = uow self.session = session + self.mapper_flush_opts = session._mapper_flush_opts # stores tuples of mapper/dependent mapper pairs, # representing a partial ordering fed into topological sort