From: Mike Bayer Date: Thu, 14 Dec 2017 15:20:50 +0000 (-0500) Subject: Add an identity_token to the identity key X-Git-Tag: rel_1_2_0~5^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=50d9f1687a6e0c3ce9b62fe98b76b25af7b20c11;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add an identity_token to the identity key For the purposes of assisting with sharded setups, add a new member to the identity key that can be customized. this allows sharding across databases where the primary key space is shared. Change-Id: Iae3909f5d4c501b62c10d0371fbceb01abda51db Fixes: #4137 --- diff --git a/doc/build/changelog/migration_12.rst b/doc/build/changelog/migration_12.rst index 9c0218c3a0..cd2a49dd49 100644 --- a/doc/build/changelog/migration_12.rst +++ b/doc/build/changelog/migration_12.rst @@ -567,6 +567,55 @@ to query across the two proxies ``A.c_values``, ``AtoB.c_value``: :ticket:`3769` +.. _change_4137: + +Identity key enhancements to support sharding +--------------------------------------------- + +The identity key structure used by the ORM now contains an additional +member, so that two identical primary keys that originate from different +contexts can co-exist within the same identity map. + +The example at :ref:`examples_sharding` has been updated to illustrate this +behavior. The example shows a sharded class ``WeatherLocation`` that +refers to a dependent ``WeatherReport`` object, where the ``WeatherReport`` +class is mapped to a table that stores a simple integer primary key. Two +``WeatherReport`` objects from different databases may have the same +primary key value. 
The example now illustrates that a new ``identity_token`` +field tracks this difference so that the two objects can co-exist in the +same identity map:: + + tokyo = WeatherLocation('Asia', 'Tokyo') + newyork = WeatherLocation('North America', 'New York') + + tokyo.reports.append(Report(80.0)) + newyork.reports.append(Report(75)) + + sess = create_session() + + sess.add_all([tokyo, newyork, quito]) + + sess.commit() + + # the Report class uses a simple integer primary key. So across two + # databases, a primary key will be repeated. The "identity_token" tracks + # in memory that these two identical primary keys are local to different + # databases. + + newyork_report = newyork.reports[0] + tokyo_report = tokyo.reports[0] + + assert inspect(newyork_report).identity_key == (Report, (1, ), "north_america") + assert inspect(tokyo_report).identity_key == (Report, (1, ), "asia") + + # the token representing the originating shard is also available directly + + assert inspect(newyork_report).identity_token == "north_america" + assert inspect(tokyo_report).identity_token == "asia" + + +:ticket:`4137` + New Features and Improvements - Core ==================================== diff --git a/doc/build/changelog/unreleased_12/4137.rst b/doc/build/changelog/unreleased_12/4137.rst new file mode 100644 index 0000000000..c619914b83 --- /dev/null +++ b/doc/build/changelog/unreleased_12/4137.rst @@ -0,0 +1,17 @@ +.. change:: + :tags: orm, feature + :tickets: 4137 + + Added a new data member to the identity key tuple + used by the ORM's identity map, known as the + "identity_token". This token defaults to None but + may be used by database sharding schemes to differentiate + objects in memory with the same primary key that come + from different databases. The horizontal sharding + extension integrates this token applying the shard + identifier to it, thus allowing primary keys to be + duplicated across horizontally sharded backends. + + .. 
seealso:: + + :ref:`change_4137` \ No newline at end of file diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py index 4ce8c247f4..cd9b14d5e2 100644 --- a/examples/sharding/attribute_shard.py +++ b/examples/sharding/attribute_shard.py @@ -1,14 +1,13 @@ - -# step 1. imports -from sqlalchemy import (create_engine, MetaData, Table, Column, Integer, - String, ForeignKey, Float, DateTime, event) -from sqlalchemy.orm import sessionmaker, mapper, relationship +from sqlalchemy import (create_engine, Table, Column, Integer, + String, ForeignKey, Float, DateTime) +from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.ext.horizontal_shard import ShardedSession from sqlalchemy.sql import operators, visitors +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import inspect import datetime -# step 2. databases. # db1 is used for id generation. The "pool_threadlocal" # causes the id_generator() to use the same connection as that # of an ongoing transaction within db1. @@ -19,61 +18,79 @@ db3 = create_engine('sqlite://', echo=echo) db4 = create_engine('sqlite://', echo=echo) -# step 3. create session function. this binds the shard ids +# create session function. this binds the shard ids # to databases within a ShardedSession and returns it. create_session = sessionmaker(class_=ShardedSession) create_session.configure(shards={ - 'north_america':db1, - 'asia':db2, - 'europe':db3, - 'south_america':db4 + 'north_america': db1, + 'asia': db2, + 'europe': db3, + 'south_america': db4 }) -# step 4. table setup. -meta = MetaData() +# mappings and tables +Base = declarative_base() -# we need a way to create identifiers which are unique across all -# databases. one easy way would be to just use a composite primary key, where one -# value is the shard id. but here, we'll show something more "generic", an -# id generation function. we'll use a simplistic "id table" stored in database -# #1. 
Any other method will do just as well; UUID, hilo, application-specific, etc. +# we need a way to create identifiers which are unique across all databases. +# one easy way would be to just use a composite primary key, where one value +# is the shard id. but here, we'll show something more "generic", an id +# generation function. we'll use a simplistic "id table" stored in database +# #1. Any other method will do just as well; UUID, hilo, application-specific, +# etc. -ids = Table('ids', meta, +ids = Table( + 'ids', Base.metadata, Column('nextid', Integer, nullable=False)) + def id_generator(ctx): # in reality, might want to use a separate transaction for this. - c = db1.connect() - nextid = c.execute(ids.select(for_update=True)).scalar() - c.execute(ids.update(values={ids.c.nextid : ids.c.nextid + 1})) + with db1.connect() as conn: + nextid = conn.scalar(ids.select(for_update=True)) + conn.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1})) return nextid -# table setup. we'll store a lead table of continents/cities, -# and a secondary table storing locations. -# a particular row will be placed in the database whose shard id corresponds to the -# 'continent'. in this setup, secondary rows in 'weather_reports' will -# be placed in the same DB as that of the parent, but this can be changed -# if you're willing to write more complex sharding functions. - -weather_locations = Table("weather_locations", meta, - Column('id', Integer, primary_key=True, default=id_generator), - Column('continent', String(30), nullable=False), - Column('city', String(50), nullable=False) - ) - -weather_reports = Table("weather_reports", meta, - Column('id', Integer, primary_key=True), - Column('location_id', Integer, ForeignKey('weather_locations.id')), - Column('temperature', Float), - Column('report_time', DateTime, default=datetime.datetime.now), -) +# table setup. we'll store a lead table of continents/cities, and a secondary +# table storing locations. 
a particular row will be placed in the database +# whose shard id corresponds to the 'continent'. in this setup, secondary rows +# in 'weather_reports' will be placed in the same DB as that of the parent, but +# this can be changed if you're willing to write more complex sharding +# functions. + + +class WeatherLocation(Base): + __tablename__ = "weather_locations" + + id = Column(Integer, primary_key=True, default=id_generator) + continent = Column(String(30), nullable=False) + city = Column(String(50), nullable=False) + + reports = relationship("Report", backref='location') + + def __init__(self, continent, city): + self.continent = continent + self.city = city + + +class Report(Base): + __tablename__ = "weather_reports" + + id = Column(Integer, primary_key=True) + location_id = Column( + 'location_id', Integer, ForeignKey('weather_locations.id')) + temperature = Column('temperature', Float) + report_time = Column( + 'report_time', DateTime, default=datetime.datetime.now) + + def __init__(self, temperature): + self.temperature = temperature # create tables for db in (db1, db2, db3, db4): - meta.drop_all(db) - meta.create_all(db) + Base.metadata.drop_all(db) + Base.metadata.create_all(db) # establish initial "id" in db1 db1.execute(ids.insert(), nextid=1) @@ -84,12 +101,13 @@ db1.execute(ids.insert(), nextid=1) # we'll use a straight mapping of a particular set of "country" # attributes to shard id. shard_lookup = { - 'North America':'north_america', - 'Asia':'asia', - 'Europe':'europe', - 'South America':'south_america' + 'North America': 'north_america', + 'Asia': 'asia', + 'Europe': 'europe', + 'South America': 'south_america' } + def shard_chooser(mapper, instance, clause=None): """shard chooser. @@ -104,6 +122,7 @@ def shard_chooser(mapper, instance, clause=None): else: return shard_chooser(mapper, instance.location) + def id_chooser(query, ident): """id chooser. 
@@ -116,6 +135,7 @@ def id_chooser(query, ident): """ return ['north_america', 'asia', 'europe', 'south_america'] + def query_chooser(query): """query chooser. @@ -133,9 +153,9 @@ def query_chooser(query): # statement column, adjusting for any annotations present. # (an annotation is an internal clone of a Column object # and occur when using ORM-mapped attributes like - # "WeatherLocation.continent"). A simpler comparison, though less accurate, - # would be "column.key == 'continent'". - if column.shares_lineage(weather_locations.c.continent): + # "WeatherLocation.continent"). A simpler comparison, though less + # accurate, would be "column.key == 'continent'". + if column.shares_lineage(WeatherLocation.__table__.c.continent): if operator == operators.eq: ids.append(shard_lookup[value]) elif operator == operators.in_op: @@ -146,6 +166,7 @@ def query_chooser(query): else: return ids + def _get_query_comparisons(query): """Search an orm.Query object for binary expressions. @@ -185,65 +206,39 @@ def _get_query_comparisons(query): binary.operator == operators.in_op and \ hasattr(binary.right, 'clauses'): comparisons.append( - (binary.left, binary.operator, + ( + binary.left, binary.operator, tuple(binds[bind] for bind in binary.right.clauses) ) ) elif binary.left in clauses and binary.right in binds: comparisons.append( - (binary.left, binary.operator,binds[binary.right]) + (binary.left, binary.operator, binds[binary.right]) ) elif binary.left in binds and binary.right in clauses: comparisons.append( - (binary.right, binary.operator,binds[binary.left]) + (binary.right, binary.operator, binds[binary.left]) ) # here we will traverse through the query's criterion, searching # for SQL constructs. We will place simple column comparisons # into a list. 
if query._criterion is not None: - visitors.traverse_depthfirst(query._criterion, {}, - {'bindparam':visit_bindparam, - 'binary':visit_binary, - 'column':visit_column - } + visitors.traverse_depthfirst( + query._criterion, {}, + {'bindparam': visit_bindparam, + 'binary': visit_binary, + 'column': visit_column} ) return comparisons # further configure create_session to use these functions create_session.configure( - shard_chooser=shard_chooser, - id_chooser=id_chooser, - query_chooser=query_chooser - ) - -# step 6. mapped classes. -class WeatherLocation(object): - def __init__(self, continent, city): - self.continent = continent - self.city = city - -class Report(object): - def __init__(self, temperature): - self.temperature = temperature - -# step 7. mappers -mapper(WeatherLocation, weather_locations, properties={ - 'reports':relationship(Report, backref='location') -}) - -mapper(Report, weather_reports) - -# step 8 (optional), events. The "shard_id" is placed -# in the QueryContext where it can be intercepted and associated -# with objects, if needed. - -def add_shard_id(instance, ctx): - instance.shard_id = ctx.attributes["shard_id"] - -event.listen(WeatherLocation, "load", add_shard_id) -event.listen(Report, "load", add_shard_id) + shard_chooser=shard_chooser, + id_chooser=id_chooser, + query_chooser=query_chooser +) # save and load objects! 
@@ -260,21 +255,33 @@ newyork.reports.append(Report(75)) quito.reports.append(Report(85)) sess = create_session() -for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]: - sess.add(c) -sess.commit() -tokyo_id = tokyo.id +sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito]) -sess.close() +sess.commit() -t = sess.query(WeatherLocation).get(tokyo_id) +t = sess.query(WeatherLocation).get(tokyo.id) assert t.city == tokyo.city assert t.reports[0].temperature == 80.0 -north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America') -assert [c.city for c in north_american_cities] == ['New York', 'Toronto'] +north_american_cities = sess.query(WeatherLocation).filter( + WeatherLocation.continent == 'North America') +assert {c.city for c in north_american_cities} == {'New York', 'Toronto'} + +asia_and_europe = sess.query(WeatherLocation).filter( + WeatherLocation.continent.in_(['Europe', 'Asia'])) +assert {c.city for c in asia_and_europe} == {'Tokyo', 'London', 'Dublin'} + +# the Report class uses a simple integer primary key. So across two databases, +# a primary key will be repeated. The "identity_token" tracks in memory +# that these two identical primary keys are local to different databases. 
+newyork_report = newyork.reports[0] +tokyo_report = tokyo.reports[0] + +assert inspect(newyork_report).identity_key == (Report, (1, ), "north_america") +assert inspect(tokyo_report).identity_key == (Report, (1, ), "asia") -asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_(['Europe', 'Asia'])) -assert set([c.city for c in asia_and_europe]) == set(['Tokyo', 'London', 'Dublin']) +# the token representing the originating shard is also available directly +assert inspect(newyork_report).identity_token == "north_america" +assert inspect(tokyo_report).identity_token == "asia" diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 8902ae6065..c5cf98b404 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -15,6 +15,7 @@ the source distribution. """ +from .. import inspect from .. import util from ..orm.session import Session from ..orm.query import Query @@ -42,7 +43,7 @@ class ShardedQuery(Query): def _execute_and_instances(self, context): def iter_for_shard(shard_id): - context.attributes['shard_id'] = shard_id + context.attributes['shard_id'] = context.identity_token = shard_id result = self._connection_from_session( mapper=self._mapper_zero(), shard_id=shard_id).execute( @@ -62,6 +63,9 @@ class ShardedQuery(Query): return iter(partial) def _get_impl(self, ident, fallback_fn): + # TODO: the "ident" here should be getting the identity token + # which indicates that this area can likely be simplified, as the + # token will fall through into _execute_and_instances def _fallback(query, ident): if self._shard_id is not None: return fallback_fn(self, ident) @@ -75,7 +79,13 @@ class ShardedQuery(Query): else: return None - return super(ShardedQuery, self)._get_impl(ident, _fallback) + if self._shard_id is not None: + identity_token = self._shard_id + else: + identity_token = None + + return super(ShardedQuery, self)._get_impl( + ident, _fallback, 
identity_token=identity_token) class ShardedSession(Session): @@ -112,9 +122,24 @@ class ShardedSession(Session): for k in shards: self.bind_shard(k, shards[k]) + def _choose_shard_and_assign(self, mapper, instance, **kw): + if instance is not None: + state = inspect(instance) + if state.key: + token = state.key[2] + assert token is not None + return token + elif state.identity_token: + return state.identity_token + + shard_id = self.shard_chooser(mapper, instance, **kw) + if instance is not None: + state.identity_token = shard_id + return shard_id + def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): if shard_id is None: - shard_id = self.shard_chooser(mapper, instance) + shard_id = self._choose_shard_and_assign(mapper, instance) if self.transaction is not None: return self.transaction.connection(mapper, shard_id=shard_id) @@ -128,7 +153,8 @@ class ShardedSession(Session): def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw): if shard_id is None: - shard_id = self.shard_chooser(mapper, instance, clause=clause) + shard_id = self._choose_shard_and_assign( + mapper, instance, clause=clause) return self.__binds[shard_id] def bind_shard(self, shard_id, bind): diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index a25a1422d5..a23cafac2c 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -369,6 +369,7 @@ def _instance_processor( session_id = context.session.hash_key version_check = context.version_check runid = context.runid + identity_token = context.identity_token if not refresh_state and _polymorphic_from is not None: key = ('loader', path.path) @@ -430,7 +431,8 @@ def _instance_processor( # session, or we have to create a new one identitykey = ( identity_class, - tuple([row[column] for column in pk_cols]) + tuple([row[column] for column in pk_cols]), + identity_token ) instance = session_identity_map.get(identitykey) @@ -464,6 +466,7 @@ def _instance_processor( 
dict_ = instance_dict(instance) state = instance_state(instance) state.key = identitykey + state.identity_token = identity_token # attach instance to session. state.session_id = session_id diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 31a93f42e7..8317c914b2 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2506,7 +2506,7 @@ class Mapper(InspectionAttr): else: return True - def identity_key_from_row(self, row, adapter=None): + def identity_key_from_row(self, row, identity_token=None, adapter=None): """Return an identity-map key for use in storing/retrieving an item from the identity map. @@ -2522,16 +2522,16 @@ class Mapper(InspectionAttr): pk_cols = [adapter.columns[c] for c in pk_cols] return self._identity_class, \ - tuple(row[column] for column in pk_cols) + tuple(row[column] for column in pk_cols), identity_token - def identity_key_from_primary_key(self, primary_key): + def identity_key_from_primary_key(self, primary_key, identity_token=None): """Return an identity-map key for use in storing/retrieving an item from an identity map. :param primary_key: A list of values indicating the identifier. """ - return self._identity_class, tuple(primary_key) + return self._identity_class, tuple(primary_key), identity_token def identity_key_from_instance(self, instance): """Return the identity key for the given instance, based on @@ -2546,17 +2546,18 @@ class Mapper(InspectionAttr): attribute name `key`. """ - return self.identity_key_from_primary_key( - self.primary_key_from_instance(instance)) + state = attributes.instance_state(instance) + return self._identity_key_from_state(state, attributes.PASSIVE_OFF) - def _identity_key_from_state(self, state): + def _identity_key_from_state( + self, state, passive=attributes.PASSIVE_RETURN_NEVER_SET): dict_ = state.dict manager = state.manager return self._identity_class, tuple([ - manager[self._columntoproperty[col].key]. 
- impl.get(state, dict_, attributes.PASSIVE_RETURN_NEVER_SET) - for col in self.primary_key - ]) + manager[prop.key]. + impl.get(state, dict_, passive) + for prop in self._identity_key_props + ]), state.identity_token def primary_key_from_instance(self, instance): """Return the list of primary key values for the given @@ -2569,17 +2570,9 @@ class Mapper(InspectionAttr): """ state = attributes.instance_state(instance) - return self._primary_key_from_state(state, attributes.PASSIVE_OFF) - - def _primary_key_from_state( - self, state, passive=attributes.PASSIVE_RETURN_NEVER_SET): - dict_ = state.dict - manager = state.manager - return [ - manager[prop.key]. - impl.get(state, dict_, passive) - for prop in self._identity_key_props - ] + identity_key = self._identity_key_from_state( + state, attributes.PASSIVE_OFF) + return identity_key[1] @_memoized_configured_property def _identity_key_props(self): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index dfa05a85e9..942ac2b245 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1402,7 +1402,7 @@ class BulkEvaluate(BulkUD): # TODO: detect when the where clause is a trivial primary key match self.matched_objects = [ - obj for (cls, pk), obj in + obj for (cls, pk, identity_token), obj in query.session.identity_map.items() if issubclass(cls, target_cls) and eval_condition(obj)] diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 209bb6d6a5..8668b312b0 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -867,9 +867,10 @@ class Query(object): :return: The object instance, or ``None``. 
""" - return self._get_impl(ident, loading.load_on_ident) + return self._get_impl( + ident, loading.load_on_ident) - def _get_impl(self, ident, fallback_fn): + def _get_impl(self, ident, fallback_fn, identity_token=None): # convert composite types to individual args if hasattr(ident, '__composite_values__'): ident = ident.__composite_values__() @@ -884,7 +885,8 @@ class Query(object): "primary key for query.get(); primary key columns are %s" % ','.join("'%s'" % c for c in mapper.primary_key)) - key = mapper.identity_key_from_primary_key(ident) + key = mapper.identity_key_from_primary_key( + ident, identity_token=identity_token) if not self._populate_existing and \ not mapper.always_refresh and \ @@ -4127,7 +4129,7 @@ class QueryContext(object): 'eager_joins', 'create_eager_joins', 'propagate_options', 'attributes', 'statement', 'from_clause', 'whereclause', 'order_by', 'labels', '_for_update_arg', 'runid', 'partials', - 'post_load_paths' + 'post_load_paths', 'identity_token' ) def __init__(self, query): @@ -4164,6 +4166,7 @@ class QueryContext(object): self.propagate_options = set(o for o in query._with_options if o.propagate_to_loaders) self.attributes = query._attributes.copy() + self.identity_token = None class AliasOption(interfaces.MapperOption): diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 4964c22e65..04ccf7e3c9 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -63,6 +63,7 @@ class InstanceState(interfaces.InspectionAttr): _load_pending = False _orphaned_outside_of_session = False is_instance = True + identity_token = None callables = () """A namespace where a per-state loader callable can be associated. 
@@ -462,21 +463,34 @@ class InstanceState(interfaces.InspectionAttr): if 'callables' in state_dict: self.callables = state_dict['callables'] - try: - self.expired_attributes = state_dict['expired_attributes'] - except KeyError: - self.expired_attributes = set() - # 0.9 and earlier compat - for k in list(self.callables): - if self.callables[k] is self: - self.expired_attributes.add(k) - del self.callables[k] + try: + self.expired_attributes = state_dict['expired_attributes'] + except KeyError: + self.expired_attributes = set() + # 0.9 and earlier compat + for k in list(self.callables): + if self.callables[k] is self: + self.expired_attributes.add(k) + del self.callables[k] + else: + if 'expired_attributes' in state_dict: + self.expired_attributes = state_dict['expired_attributes'] + else: + self.expired_attributes = set() self.__dict__.update([ (k, state_dict[k]) for k in ( - 'key', 'load_options', + 'key', 'load_options' ) if k in state_dict ]) + if self.key: + try: + self.identity_token = self.key[2] + except IndexError: + # 1.1 and earlier compat before identity_token + assert len(self.key) == 2 + self.key = self.key + (None, ) + self.identity_token = None if 'load_path' in state_dict: self.load_path = PathRegistry.\ diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4267b79fb5..ad2a74a34c 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -214,7 +214,7 @@ def identity_key(*args, **kwargs): This function has several call styles: - * ``identity_key(class, ident)`` + * ``identity_key(class, ident, identity_token=token)`` This form receives a mapped class and a primary key scalar or tuple as an argument. @@ -222,10 +222,13 @@ def identity_key(*args, **kwargs): E.g.:: >>> identity_key(MyClass, (1, 2)) - (<class '__main__.MyClass'>, (1, 2)) + (<class '__main__.MyClass'>, (1, 2), None) :param class: mapped class (must be a positional argument) :param ident: primary key, may be a scalar or tuple argument. + :param identity_token: optional identity token + + .. 
versionadded:: 1.2 added identity_token * ``identity_key(instance=instance)`` @@ -239,7 +242,7 @@ def identity_key(*args, **kwargs): >>> instance = MyClass(1, 2) >>> identity_key(instance=instance) - (<class '__main__.MyClass'>, (1, 2)) + (<class '__main__.MyClass'>, (1, 2), None) In this form, the given instance is ultimately run through :meth:`.Mapper.identity_key_from_instance`, which will have the @@ -248,7 +251,7 @@ def identity_key(*args, **kwargs): :param instance: object instance (must be given as a keyword arg) - * ``identity_key(class, row=row)`` + * ``identity_key(class, row=row, identity_token=token)`` This form is similar to the class/tuple form, except is passed a database result row as a :class:`.RowProxy` object. @@ -258,41 +261,50 @@ def identity_key(*args, **kwargs): >>> row = engine.execute("select * from table where a=1 and b=2").\ first() >>> identity_key(MyClass, row=row) - (<class '__main__.MyClass'>, (1, 2)) + (<class '__main__.MyClass'>, (1, 2), None) :param class: mapped class (must be a positional argument) :param row: :class:`.RowProxy` row returned by a :class:`.ResultProxy` (must be given as a keyword arg) + :param identity_token: optional identity token + + .. 
versionadded:: 1.2 added identity_token """ if args: - if len(args) == 1: + row = None + largs = len(args) + if largs == 1: class_ = args[0] try: row = kwargs.pop("row") except KeyError: ident = kwargs.pop("ident") - elif len(args) == 2: - class_, ident = args - elif len(args) == 3: + elif largs in (2, 3): class_, ident = args else: raise sa_exc.ArgumentError( "expected up to three positional arguments, " - "got %s" % len(args)) + "got %s" % largs) + + identity_token = kwargs.pop("identity_token", None) if kwargs: raise sa_exc.ArgumentError("unknown keyword arguments: %s" % ", ".join(kwargs)) mapper = class_mapper(class_) - if "ident" in locals(): - return mapper.identity_key_from_primary_key(util.to_list(ident)) - return mapper.identity_key_from_row(row) - instance = kwargs.pop("instance") - if kwargs: - raise sa_exc.ArgumentError("unknown keyword arguments: %s" - % ", ".join(kwargs.keys)) - mapper = object_mapper(instance) - return mapper.identity_key_from_instance(instance) + if row is None: + return mapper.identity_key_from_primary_key( + util.to_list(ident), identity_token=identity_token) + else: + return mapper.identity_key_from_row( + row, identity_token=identity_token) + else: + instance = kwargs.pop("instance") + if kwargs: + raise sa_exc.ArgumentError("unknown keyword arguments: %s" + % ", ".join(kwargs.keys)) + mapper = object_mapper(instance) + return mapper.identity_key_from_instance(instance) class ORMAdapter(sql_util.ColumnAdapter): diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 79487b2a79..c09753e5df 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -125,8 +125,10 @@ class ShardTest(object): self.city = city class Report(object): - def __init__(self, temperature): + def __init__(self, temperature, id_=None): self.temperature = temperature + if id_: + self.id = id_ mapper(WeatherLocation, weather_locations, properties={ 'reports': relationship(Report, 
backref='location'), @@ -143,8 +145,8 @@ class ShardTest(object): dublin = WeatherLocation('Europe', 'Dublin') brasilia = WeatherLocation('South America', 'Brasila') quito = WeatherLocation('South America', 'Quito') - tokyo.reports.append(Report(80.0)) - newyork.reports.append(Report(75)) + tokyo.reports.append(Report(80.0, id_=1)) + newyork.reports.append(Report(75, id_=1)) quito.reports.append(Report(85)) sess = create_session() for c in [ @@ -157,6 +159,13 @@ class ShardTest(object): quito, ]: sess.add(c) + sess.flush() + + eq_(inspect(newyork).key[2], "north_america") + eq_(inspect(newyork).identity_token, "north_america") + eq_(inspect(dublin).key[2], "europe") + eq_(inspect(dublin).identity_token, "europe") + sess.commit() sess.close() return sess @@ -165,7 +174,7 @@ class ShardTest(object): sess = self._fixture_data() tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one() tokyo.city # reload 'city' attribute on tokyo - sess.expunge_all() + sess.expire_all() eq_(db2.execute(weather_locations.select()).fetchall(), [(1, 'Asia', 'Tokyo')]) eq_(db1.execute(weather_locations.select()).fetchall(), [(2, @@ -186,6 +195,33 @@ class ShardTest(object): eq_(set([c.city for c in asia_and_europe]), set(['Tokyo', 'London', 'Dublin'])) + # inspect the shard token stored with each instance + eq_( + set(inspect(c).key[2] for c in asia_and_europe), + set(['europe', 'asia'])) + + eq_( + set(inspect(c).identity_token for c in asia_and_europe), + set(['europe', 'asia'])) + + newyork = sess.query(WeatherLocation).filter_by(city="New York").one() + newyork_report = newyork.reports[0] + tokyo_report = tokyo.reports[0] + + # same primary key, two identity keys + eq_( + inspect(newyork_report).identity_key, + (Report, (1, ), "north_america") + ) + eq_( + inspect(tokyo_report).identity_key, + (Report, (1, ), "asia") + ) + + # the token representing the originating shard is available + eq_(inspect(newyork_report).identity_token, "north_america") + 
eq_(inspect(tokyo_report).identity_token, "asia") + def test_get_baked_query(self): sess = self._fixture_data() @@ -201,6 +237,8 @@ class ShardTest(object): t = bq(sess).get(tokyo.id) eq_(t.city, tokyo.city) + eq_(inspect(t).key[2], 'asia') + def test_get_baked_query_shard_id(self): sess = self._fixture_data() @@ -217,6 +255,8 @@ class ShardTest(object): lambda q: q.set_shard("asia")).get(tokyo.id) eq_(t.city, tokyo.city) + eq_(inspect(t).key[2], 'asia') + def test_filter_baked_query_shard_id(self): sess = self._fixture_data() diff --git a/test/orm/test_inspect.py b/test/orm/test_inspect.py index fc061676be..a67ac4419c 100644 --- a/test/orm/test_inspect.py +++ b/test/orm/test_inspect.py @@ -7,6 +7,7 @@ from test.orm import _fixtures from sqlalchemy.orm import class_mapper, synonym, Session, aliased from sqlalchemy.orm.attributes import instance_state, NO_VALUE from sqlalchemy import testing +from sqlalchemy.orm.util import identity_key class TestORMInspection(_fixtures.FixtureTest): @@ -487,7 +488,7 @@ class TestORMInspection(_fixtures.FixtureTest): insp = inspect(u1) eq_( insp.identity_key, - (User, (u1.id, )) + identity_key(User, (u1.id, )) ) def test_persistence_states(self): diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index 6b925806f5..2369acb960 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -290,10 +290,44 @@ class PickleTest(fixtures.MappedTest): manager.class_ = User state_09['manager'] = manager state.__setstate__(state_09) + eq_(state.expired_attributes, {'name', 'id'}) sess = Session() sess.add(inst) eq_(inst.name, 'ed') + # test identity_token expansion + eq_(sa.inspect(inst).key, (User, (1, ), None)) + + def test_11_pickle(self): + users = self.tables.users + mapper(User, users) + sess = Session() + u1 = User(id=1, name='ed') + sess.add(u1) + sess.commit() + + sess.close() + + manager = instrumentation._SerializeManager.__new__( + instrumentation._SerializeManager) + manager.class_ = User + + state_11 = { + 
+ 'class_': User, + 'modified': False, + 'committed_state': {}, + 'instance': u1, + 'manager': manager, + 'key': (User, (1,)), + 'expired_attributes': set(), + 'expired': True} + + state = sa_state.InstanceState.__new__(sa_state.InstanceState) + state.__setstate__(state_11) + + eq_(state.identity_token, None) + eq_(state.identity_key, (User, (1,), None)) @testing.requires.non_broken_pickle def test_options_with_descriptors(self): diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 1510689f9f..c639418d9d 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -2,6 +2,7 @@ from __future__ import with_statement from sqlalchemy import ( testing, exc as sa_exc, event, String, Column, Table, select, func) from sqlalchemy.sql import elements +from sqlalchemy.orm.util import identity_key from sqlalchemy.testing import ( fixtures, engines, eq_, assert_raises, assert_raises_message, assert_warnings, mock, expect_warnings, is_, is_not_) @@ -12,6 +13,7 @@ from sqlalchemy.testing.util import gc_collect from test.orm._fixtures import FixtureTest from sqlalchemy import inspect + class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): run_inserts = None __backend__ = True @@ -1715,8 +1717,8 @@ class NaturalPKRollbackTest(fixtures.MappedTest): assert u1 in s assert u2 in s - assert s.identity_map[(User, ('u1',))] is u1 - assert s.identity_map[(User, ('u2',))] is u2 + assert s.identity_map[identity_key(User, ('u1',))] is u1 + assert s.identity_map[identity_key(User, ('u2',))] is u2 def test_multiple_key_replaced_by_update(self): users, User = self.tables.users, self.classes.User @@ -1747,9 +1749,9 @@ class NaturalPKRollbackTest(fixtures.MappedTest): assert u2 in s assert u3 in s - assert s.identity_map[(User, ('u1',))] is u1 - assert s.identity_map[(User, ('u2',))] is u2 - assert s.identity_map[(User, ('u3',))] is u3 + assert s.identity_map[identity_key(User, ('u1',))] is u1 + assert s.identity_map[identity_key(User, 
('u2',))] is u2 + assert s.identity_map[identity_key(User, ('u3',))] is u3 def test_key_replaced_by_oob_insert(self): users, User = self.tables.users, self.classes.User @@ -1774,4 +1776,4 @@ class NaturalPKRollbackTest(fixtures.MappedTest): assert u1 in s assert u2 not in s - assert s.identity_map[(User, ('u1',))] is u1 + assert s.identity_map[identity_key(User, ('u1',))] is u1 diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py index b06884b867..44161ddcde 100644 --- a/test/orm/test_utils.py +++ b/test/orm/test_utils.py @@ -301,9 +301,9 @@ class IdentityKeyTest(_fixtures.FixtureTest): mapper(User, users) key = orm_util.identity_key(User, [1]) - eq_(key, (User, (1,))) + eq_(key, (User, (1,), None)) key = orm_util.identity_key(User, ident=[1]) - eq_(key, (User, (1,))) + eq_(key, (User, (1,), None)) def test_identity_key_scalar(self): User, users = self.classes.User, self.tables.users @@ -311,9 +311,9 @@ class IdentityKeyTest(_fixtures.FixtureTest): mapper(User, users) key = orm_util.identity_key(User, 1) - eq_(key, (User, (1,))) + eq_(key, (User, (1,), None)) key = orm_util.identity_key(User, ident=1) - eq_(key, (User, (1,))) + eq_(key, (User, (1,), None)) def test_identity_key_2(self): users, User = self.tables.users, self.classes.User @@ -324,7 +324,7 @@ class IdentityKeyTest(_fixtures.FixtureTest): s.add(u) s.flush() key = orm_util.identity_key(instance=u) - eq_(key, (User, (u.id,))) + eq_(key, (User, (u.id,), None)) def test_identity_key_3(self): User, users = self.classes.User, self.tables.users @@ -333,7 +333,17 @@ class IdentityKeyTest(_fixtures.FixtureTest): row = {users.c.id: 1, users.c.name: "Frank"} key = orm_util.identity_key(User, row=row) - eq_(key, (User, (1,))) + eq_(key, (User, (1,), None)) + + def test_identity_key_token(self): + User, users = self.classes.User, self.tables.users + + mapper(User, users) + + key = orm_util.identity_key(User, [1], identity_token="token") + eq_(key, (User, (1,), "token")) + key = orm_util.identity_key(User, 
ident=[1], identity_token="token") + eq_(key, (User, (1,), "token")) class PathRegistryTest(_fixtures.FixtureTest):