]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add an identity_token to the identity key
author Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Dec 2017 15:20:50 +0000 (10:20 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Dec 2017 16:36:53 +0000 (11:36 -0500)
To assist with sharded setups, add a new member to the identity
key that can be customized.  This allows sharding across
databases where the primary key space is shared.

Change-Id: Iae3909f5d4c501b62c10d0371fbceb01abda51db
Fixes: #4137
15 files changed:
doc/build/changelog/migration_12.rst
doc/build/changelog/unreleased_12/4137.rst [new file with mode: 0644]
examples/sharding/attribute_shard.py
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/util.py
test/ext/test_horizontal_shard.py
test/orm/test_inspect.py
test/orm/test_pickled.py
test/orm/test_transaction.py
test/orm/test_utils.py

index 9c0218c3a0962fdfcfe69491f076365a92d1030f..cd2a49dd49e7221b035748ed281d88260125485d 100644 (file)
@@ -567,6 +567,55 @@ to query across the two proxies ``A.c_values``, ``AtoB.c_value``:
 
 :ticket:`3769`
 
+.. _change_4137:
+
+Identity key enhancements to support sharding
+---------------------------------------------
+
+The identity key structure used by the ORM now contains an additional
+member, so that two identical primary keys that originate from different
+contexts can co-exist within the same identity map.
+
+The example at :ref:`examples_sharding` has been updated to illustrate this
+behavior.  The example shows a sharded class ``WeatherLocation`` that
+refers to a dependent ``WeatherReport`` object, where the ``WeatherReport``
+class is mapped to a table that stores a simple integer primary key.  Two
+``WeatherReport`` objects from different databases may have the same
+primary key value.   The example now illustrates that a new ``identity_token``
+field tracks this difference so that the two objects can co-exist in the
+same identity map::
+
+    tokyo = WeatherLocation('Asia', 'Tokyo')
+    newyork = WeatherLocation('North America', 'New York')
+
+    tokyo.reports.append(Report(80.0))
+    newyork.reports.append(Report(75))
+
+    sess = create_session()
+
+    sess.add_all([tokyo, newyork])
+
+    sess.commit()
+
+    # the Report class uses a simple integer primary key.  So across two
+    # databases, a primary key will be repeated.  The "identity_token" tracks
+    # in memory that these two identical primary keys are local to different
+    # databases.
+
+    newyork_report = newyork.reports[0]
+    tokyo_report = tokyo.reports[0]
+
+    assert inspect(newyork_report).identity_key == (Report, (1, ), "north_america")
+    assert inspect(tokyo_report).identity_key == (Report, (1, ), "asia")
+
+    # the token representing the originating shard is also available directly
+
+    assert inspect(newyork_report).identity_token == "north_america"
+    assert inspect(tokyo_report).identity_token == "asia"
+
+
+:ticket:`4137`
+
 New Features and Improvements - Core
 ====================================
 
diff --git a/doc/build/changelog/unreleased_12/4137.rst b/doc/build/changelog/unreleased_12/4137.rst
new file mode 100644 (file)
index 0000000..c619914
--- /dev/null
@@ -0,0 +1,17 @@
+.. change::
+    :tags: orm, feature
+    :tickets: 4137
+
+    Added a new data member to the identity key tuple
+    used by the ORM's identity map, known as the
+    "identity_token".  This token defaults to None but
+    may be used by database sharding schemes to differentiate
+    objects in memory with the same primary key that come
+    from different databases.   The horizontal sharding
+    extension integrates this token by applying the shard
+    identifier to it, thus allowing primary keys to be
+    duplicated across horizontally sharded backends.
+
+    .. seealso::
+
+        :ref:`change_4137`
\ No newline at end of file
index 4ce8c247f4ea12da49af64f8c9ea22bcf6391faf..cd9b14d5e2a77679df7eaabac536345c2ef20433 100644 (file)
@@ -1,14 +1,13 @@
-
-# step 1. imports
-from sqlalchemy import (create_engine, MetaData, Table, Column, Integer,
-    String, ForeignKey, Float, DateTime, event)
-from sqlalchemy.orm import sessionmaker, mapper, relationship
+from sqlalchemy import (create_engine, Table, Column, Integer,
+                        String, ForeignKey, Float, DateTime)
+from sqlalchemy.orm import sessionmaker, relationship
 from sqlalchemy.ext.horizontal_shard import ShardedSession
 from sqlalchemy.sql import operators, visitors
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import inspect
 
 import datetime
 
-# step 2. databases.
 # db1 is used for id generation. The "pool_threadlocal"
 # causes the id_generator() to use the same connection as that
 # of an ongoing transaction within db1.
@@ -19,61 +18,79 @@ db3 = create_engine('sqlite://', echo=echo)
 db4 = create_engine('sqlite://', echo=echo)
 
 
-# step 3. create session function.  this binds the shard ids
+# create session function.  this binds the shard ids
 # to databases within a ShardedSession and returns it.
 create_session = sessionmaker(class_=ShardedSession)
 
 create_session.configure(shards={
-    'north_america':db1,
-    'asia':db2,
-    'europe':db3,
-    'south_america':db4
+    'north_america': db1,
+    'asia': db2,
+    'europe': db3,
+    'south_america': db4
 })
 
 
-# step 4.  table setup.
-meta = MetaData()
+# mappings and tables
+Base = declarative_base()
 
-# we need a way to create identifiers which are unique across all
-# databases.  one easy way would be to just use a composite primary key, where one
-# value is the shard id.  but here, we'll show something more "generic", an
-# id generation function.  we'll use a simplistic "id table" stored in database
-# #1.  Any other method will do just as well; UUID, hilo, application-specific, etc.
+# we need a way to create identifiers which are unique across all databases.
+# one easy way would be to just use a composite primary key, where one value
+# is the shard id.  but here, we'll show something more "generic", an id
+# generation function.  we'll use a simplistic "id table" stored in database
+# #1.  Any other method will do just as well; UUID, hilo, application-specific,
+# etc.
 
-ids = Table('ids', meta,
+ids = Table(
+    'ids', Base.metadata,
     Column('nextid', Integer, nullable=False))
 
+
 def id_generator(ctx):
     # in reality, might want to use a separate transaction for this.
-    c = db1.connect()
-    nextid = c.execute(ids.select(for_update=True)).scalar()
-    c.execute(ids.update(values={ids.c.nextid : ids.c.nextid + 1}))
+    with db1.connect() as conn:
+        nextid = conn.scalar(ids.select(for_update=True))
+        conn.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1}))
     return nextid
 
-# table setup.  we'll store a lead table of continents/cities,
-# and a secondary table storing locations.
-# a particular row will be placed in the database whose shard id corresponds to the
-# 'continent'.  in this setup, secondary rows in 'weather_reports' will
-# be placed in the same DB as that of the parent, but this can be changed
-# if you're willing to write more complex sharding functions.
-
-weather_locations = Table("weather_locations", meta,
-        Column('id', Integer, primary_key=True, default=id_generator),
-        Column('continent', String(30), nullable=False),
-        Column('city', String(50), nullable=False)
-    )
-
-weather_reports = Table("weather_reports", meta,
-    Column('id', Integer, primary_key=True),
-    Column('location_id', Integer, ForeignKey('weather_locations.id')),
-    Column('temperature', Float),
-    Column('report_time', DateTime, default=datetime.datetime.now),
-)
+# table setup.  we'll store a lead table of continents/cities, and a secondary
+# table storing weather reports. a particular row will be placed in the
+# database whose shard id corresponds to the 'continent'.  in this setup,
+# secondary rows in 'weather_reports' will be placed in the same DB as that of
+# the parent, but this can be changed if you're willing to write more complex
+# sharding functions.
+
+
+class WeatherLocation(Base):
+    __tablename__ = "weather_locations"
+
+    id = Column(Integer, primary_key=True, default=id_generator)
+    continent = Column(String(30), nullable=False)
+    city = Column(String(50), nullable=False)
+
+    reports = relationship("Report", backref='location')
+
+    def __init__(self, continent, city):
+        self.continent = continent
+        self.city = city
+
+
+class Report(Base):
+    __tablename__ = "weather_reports"
+
+    id = Column(Integer, primary_key=True)
+    location_id = Column(
+        'location_id', Integer, ForeignKey('weather_locations.id'))
+    temperature = Column('temperature', Float)
+    report_time = Column(
+        'report_time', DateTime, default=datetime.datetime.now)
+
+    def __init__(self, temperature):
+        self.temperature = temperature
 
 # create tables
 for db in (db1, db2, db3, db4):
-    meta.drop_all(db)
-    meta.create_all(db)
+    Base.metadata.drop_all(db)
+    Base.metadata.create_all(db)
 
 # establish initial "id" in db1
 db1.execute(ids.insert(), nextid=1)
@@ -84,12 +101,13 @@ db1.execute(ids.insert(), nextid=1)
 # we'll use a straight mapping of a particular set of "country"
 # attributes to shard id.
 shard_lookup = {
-    'North America':'north_america',
-    'Asia':'asia',
-    'Europe':'europe',
-    'South America':'south_america'
+    'North America': 'north_america',
+    'Asia': 'asia',
+    'Europe': 'europe',
+    'South America': 'south_america'
 }
 
+
 def shard_chooser(mapper, instance, clause=None):
     """shard chooser.
 
@@ -104,6 +122,7 @@ def shard_chooser(mapper, instance, clause=None):
     else:
         return shard_chooser(mapper, instance.location)
 
+
 def id_chooser(query, ident):
     """id chooser.
 
@@ -116,6 +135,7 @@ def id_chooser(query, ident):
     """
     return ['north_america', 'asia', 'europe', 'south_america']
 
+
 def query_chooser(query):
     """query chooser.
 
@@ -133,9 +153,9 @@ def query_chooser(query):
         # statement column, adjusting for any annotations present.
         # (an annotation is an internal clone of a Column object
         # and occur when using ORM-mapped attributes like
-        # "WeatherLocation.continent"). A simpler comparison, though less accurate,
-        # would be "column.key == 'continent'".
-        if column.shares_lineage(weather_locations.c.continent):
+        # "WeatherLocation.continent"). A simpler comparison, though less
+        # accurate, would be "column.key == 'continent'".
+        if column.shares_lineage(WeatherLocation.__table__.c.continent):
             if operator == operators.eq:
                 ids.append(shard_lookup[value])
             elif operator == operators.in_op:
@@ -146,6 +166,7 @@ def query_chooser(query):
     else:
         return ids
 
+
 def _get_query_comparisons(query):
     """Search an orm.Query object for binary expressions.
 
@@ -185,65 +206,39 @@ def _get_query_comparisons(query):
                 binary.operator == operators.in_op and \
                 hasattr(binary.right, 'clauses'):
             comparisons.append(
-                (binary.left, binary.operator,
+                (
+                    binary.left, binary.operator,
                     tuple(binds[bind] for bind in binary.right.clauses)
                 )
             )
         elif binary.left in clauses and binary.right in binds:
             comparisons.append(
-                (binary.left, binary.operator,binds[binary.right])
+                (binary.left, binary.operator, binds[binary.right])
             )
 
         elif binary.left in binds and binary.right in clauses:
             comparisons.append(
-                (binary.right, binary.operator,binds[binary.left])
+                (binary.right, binary.operator, binds[binary.left])
             )
 
     # here we will traverse through the query's criterion, searching
     # for SQL constructs.  We will place simple column comparisons
     # into a list.
     if query._criterion is not None:
-        visitors.traverse_depthfirst(query._criterion, {},
-                    {'bindparam':visit_bindparam,
-                        'binary':visit_binary,
-                        'column':visit_column
-                    }
+        visitors.traverse_depthfirst(
+            query._criterion, {},
+            {'bindparam': visit_bindparam,
+                'binary': visit_binary,
+                'column': visit_column}
         )
     return comparisons
 
 # further configure create_session to use these functions
 create_session.configure(
-                    shard_chooser=shard_chooser,
-                    id_chooser=id_chooser,
-                    query_chooser=query_chooser
-                    )
-
-# step 6.  mapped classes.
-class WeatherLocation(object):
-    def __init__(self, continent, city):
-        self.continent = continent
-        self.city = city
-
-class Report(object):
-    def __init__(self, temperature):
-        self.temperature = temperature
-
-# step 7.  mappers
-mapper(WeatherLocation, weather_locations, properties={
-    'reports':relationship(Report, backref='location')
-})
-
-mapper(Report, weather_reports)
-
-# step 8 (optional), events.  The "shard_id" is placed
-# in the QueryContext where it can be intercepted and associated
-# with objects, if needed.
-
-def add_shard_id(instance, ctx):
-    instance.shard_id = ctx.attributes["shard_id"]
-
-event.listen(WeatherLocation, "load", add_shard_id)
-event.listen(Report, "load", add_shard_id)
+    shard_chooser=shard_chooser,
+    id_chooser=id_chooser,
+    query_chooser=query_chooser
+)
 
 # save and load objects!
 
@@ -260,21 +255,33 @@ newyork.reports.append(Report(75))
 quito.reports.append(Report(85))
 
 sess = create_session()
-for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
-    sess.add(c)
-sess.commit()
 
-tokyo_id = tokyo.id
+sess.add_all([tokyo, newyork, toronto, london, dublin, brasilia, quito])
 
-sess.close()
+sess.commit()
 
-t = sess.query(WeatherLocation).get(tokyo_id)
+t = sess.query(WeatherLocation).get(tokyo.id)
 assert t.city == tokyo.city
 assert t.reports[0].temperature == 80.0
 
-north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America')
-assert [c.city for c in north_american_cities] == ['New York', 'Toronto']
+north_american_cities = sess.query(WeatherLocation).filter(
+    WeatherLocation.continent == 'North America')
+assert {c.city for c in north_american_cities} == {'New York', 'Toronto'}
+
+asia_and_europe = sess.query(WeatherLocation).filter(
+    WeatherLocation.continent.in_(['Europe', 'Asia']))
+assert {c.city for c in asia_and_europe} == {'Tokyo', 'London', 'Dublin'}
+
+# the Report class uses a simple integer primary key.  So across two databases,
+# a primary key will be repeated.  The "identity_token" tracks in memory
+# that these two identical primary keys are local to different databases.
+newyork_report = newyork.reports[0]
+tokyo_report = tokyo.reports[0]
+
+assert inspect(newyork_report).identity_key == (Report, (1, ), "north_america")
+assert inspect(tokyo_report).identity_key == (Report, (1, ), "asia")
 
-asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_(['Europe', 'Asia']))
-assert set([c.city for c in asia_and_europe]) == set(['Tokyo', 'London', 'Dublin'])
+# the token representing the originating shard is also available directly
 
+assert inspect(newyork_report).identity_token == "north_america"
+assert inspect(tokyo_report).identity_token == "asia"
index 8902ae6065c94a3d38cbe92683b56a0017b1e9af..c5cf98b40416698de4c89ceb993ce2a666a4ecfb 100644 (file)
@@ -15,6 +15,7 @@ the source distribution.
 
 """
 
+from .. import inspect
 from .. import util
 from ..orm.session import Session
 from ..orm.query import Query
@@ -42,7 +43,7 @@ class ShardedQuery(Query):
 
     def _execute_and_instances(self, context):
         def iter_for_shard(shard_id):
-            context.attributes['shard_id'] = shard_id
+            context.attributes['shard_id'] = context.identity_token = shard_id
             result = self._connection_from_session(
                 mapper=self._mapper_zero(),
                 shard_id=shard_id).execute(
@@ -62,6 +63,9 @@ class ShardedQuery(Query):
             return iter(partial)
 
     def _get_impl(self, ident, fallback_fn):
+        # TODO: the "ident" here should be getting the identity token
+        # which indicates that this area can likely be simplified, as the
+        # token will fall through into _execute_and_instances
         def _fallback(query, ident):
             if self._shard_id is not None:
                 return fallback_fn(self, ident)
@@ -75,7 +79,13 @@ class ShardedQuery(Query):
                 else:
                     return None
 
-        return super(ShardedQuery, self)._get_impl(ident, _fallback)
+        if self._shard_id is not None:
+            identity_token = self._shard_id
+        else:
+            identity_token = None
+
+        return super(ShardedQuery, self)._get_impl(
+            ident, _fallback, identity_token=identity_token)
 
 
 class ShardedSession(Session):
@@ -112,9 +122,24 @@ class ShardedSession(Session):
             for k in shards:
                 self.bind_shard(k, shards[k])
 
+    def _choose_shard_and_assign(self, mapper, instance, **kw):
+        if instance is not None:
+            state = inspect(instance)
+            if state.key:
+                token = state.key[2]
+                assert token is not None
+                return token
+            elif state.identity_token:
+                return state.identity_token
+
+        shard_id = self.shard_chooser(mapper, instance, **kw)
+        if instance is not None:
+            state.identity_token = shard_id
+        return shard_id
+
     def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
         if shard_id is None:
-            shard_id = self.shard_chooser(mapper, instance)
+            shard_id = self._choose_shard_and_assign(mapper, instance)
 
         if self.transaction is not None:
             return self.transaction.connection(mapper, shard_id=shard_id)
@@ -128,7 +153,8 @@ class ShardedSession(Session):
     def get_bind(self, mapper, shard_id=None,
                  instance=None, clause=None, **kw):
         if shard_id is None:
-            shard_id = self.shard_chooser(mapper, instance, clause=clause)
+            shard_id = self._choose_shard_and_assign(
+                mapper, instance, clause=clause)
         return self.__binds[shard_id]
 
     def bind_shard(self, shard_id, bind):
index a25a1422d5a3e9a6348f063a6474a8514df10283..a23cafac2cc178dc95647177f13b6bc2d75d5f59 100644 (file)
@@ -369,6 +369,7 @@ def _instance_processor(
     session_id = context.session.hash_key
     version_check = context.version_check
     runid = context.runid
+    identity_token = context.identity_token
 
     if not refresh_state and _polymorphic_from is not None:
         key = ('loader', path.path)
@@ -430,7 +431,8 @@ def _instance_processor(
             # session, or we have to create a new one
             identitykey = (
                 identity_class,
-                tuple([row[column] for column in pk_cols])
+                tuple([row[column] for column in pk_cols]),
+                identity_token
             )
 
             instance = session_identity_map.get(identitykey)
@@ -464,6 +466,7 @@ def _instance_processor(
                 dict_ = instance_dict(instance)
                 state = instance_state(instance)
                 state.key = identitykey
+                state.identity_token = identity_token
 
                 # attach instance to session.
                 state.session_id = session_id
index 31a93f42e72943ec0c66ae0a5b478c7e3f9d8c5b..8317c914b2eac0e3f05482bb20484250a962cbb6 100644 (file)
@@ -2506,7 +2506,7 @@ class Mapper(InspectionAttr):
         else:
             return True
 
-    def identity_key_from_row(self, row, adapter=None):
+    def identity_key_from_row(self, row, identity_token=None, adapter=None):
         """Return an identity-map key for use in storing/retrieving an
         item from the identity map.
 
@@ -2522,16 +2522,16 @@ class Mapper(InspectionAttr):
             pk_cols = [adapter.columns[c] for c in pk_cols]
 
         return self._identity_class, \
-            tuple(row[column] for column in pk_cols)
+            tuple(row[column] for column in pk_cols), identity_token
 
-    def identity_key_from_primary_key(self, primary_key):
+    def identity_key_from_primary_key(self, primary_key, identity_token=None):
         """Return an identity-map key for use in storing/retrieving an
         item from an identity map.
 
         :param primary_key: A list of values indicating the identifier.
 
         """
-        return self._identity_class, tuple(primary_key)
+        return self._identity_class, tuple(primary_key), identity_token
 
     def identity_key_from_instance(self, instance):
         """Return the identity key for the given instance, based on
@@ -2546,17 +2546,18 @@ class Mapper(InspectionAttr):
         attribute name `key`.
 
         """
-        return self.identity_key_from_primary_key(
-            self.primary_key_from_instance(instance))
+        state = attributes.instance_state(instance)
+        return self._identity_key_from_state(state, attributes.PASSIVE_OFF)
 
-    def _identity_key_from_state(self, state):
+    def _identity_key_from_state(
+            self, state, passive=attributes.PASSIVE_RETURN_NEVER_SET):
         dict_ = state.dict
         manager = state.manager
         return self._identity_class, tuple([
-            manager[self._columntoproperty[col].key].
-            impl.get(state, dict_, attributes.PASSIVE_RETURN_NEVER_SET)
-            for col in self.primary_key
-        ])
+            manager[prop.key].
+            impl.get(state, dict_, passive)
+            for prop in self._identity_key_props
+        ]), state.identity_token
 
     def primary_key_from_instance(self, instance):
         """Return the list of primary key values for the given
@@ -2569,17 +2570,9 @@ class Mapper(InspectionAttr):
 
         """
         state = attributes.instance_state(instance)
-        return self._primary_key_from_state(state, attributes.PASSIVE_OFF)
-
-    def _primary_key_from_state(
-            self, state, passive=attributes.PASSIVE_RETURN_NEVER_SET):
-        dict_ = state.dict
-        manager = state.manager
-        return [
-            manager[prop.key].
-            impl.get(state, dict_, passive)
-            for prop in self._identity_key_props
-        ]
+        identity_key = self._identity_key_from_state(
+            state, attributes.PASSIVE_OFF)
+        return identity_key[1]
 
     @_memoized_configured_property
     def _identity_key_props(self):
index dfa05a85e915276ec233abeafa65268edc1ddefd..942ac2b245ce8a1d3d4edec89748b6157f43c9dd 100644 (file)
@@ -1402,7 +1402,7 @@ class BulkEvaluate(BulkUD):
 
         # TODO: detect when the where clause is a trivial primary key match
         self.matched_objects = [
-            obj for (cls, pk), obj in
+            obj for (cls, pk, identity_token), obj in
             query.session.identity_map.items()
             if issubclass(cls, target_cls) and
             eval_condition(obj)]
index 209bb6d6a50adf00d78df26416bfbaf70b918b7e..8668b312b0c70d5bfc874a8f5ca04aca71aa2b77 100644 (file)
@@ -867,9 +867,10 @@ class Query(object):
         :return: The object instance, or ``None``.
 
         """
-        return self._get_impl(ident, loading.load_on_ident)
+        return self._get_impl(
+            ident, loading.load_on_ident)
 
-    def _get_impl(self, ident, fallback_fn):
+    def _get_impl(self, ident, fallback_fn, identity_token=None):
         # convert composite types to individual args
         if hasattr(ident, '__composite_values__'):
             ident = ident.__composite_values__()
@@ -884,7 +885,8 @@ class Query(object):
                 "primary key for query.get(); primary key columns are %s" %
                 ','.join("'%s'" % c for c in mapper.primary_key))
 
-        key = mapper.identity_key_from_primary_key(ident)
+        key = mapper.identity_key_from_primary_key(
+            ident, identity_token=identity_token)
 
         if not self._populate_existing and \
                 not mapper.always_refresh and \
@@ -4127,7 +4129,7 @@ class QueryContext(object):
         'eager_joins', 'create_eager_joins', 'propagate_options',
         'attributes', 'statement', 'from_clause', 'whereclause',
         'order_by', 'labels', '_for_update_arg', 'runid', 'partials',
-        'post_load_paths'
+        'post_load_paths', 'identity_token'
     )
 
     def __init__(self, query):
@@ -4164,6 +4166,7 @@ class QueryContext(object):
         self.propagate_options = set(o for o in query._with_options if
                                      o.propagate_to_loaders)
         self.attributes = query._attributes.copy()
+        self.identity_token = None
 
 
 class AliasOption(interfaces.MapperOption):
index 4964c22e6594aa307c2857e343dee8e0ee50cc88..04ccf7e3c990ad34d3739ad91d8df343f84157dc 100644 (file)
@@ -63,6 +63,7 @@ class InstanceState(interfaces.InspectionAttr):
     _load_pending = False
     _orphaned_outside_of_session = False
     is_instance = True
+    identity_token = None
 
     callables = ()
     """A namespace where a per-state loader callable can be associated.
@@ -462,21 +463,34 @@ class InstanceState(interfaces.InspectionAttr):
         if 'callables' in state_dict:
             self.callables = state_dict['callables']
 
-        try:
-            self.expired_attributes = state_dict['expired_attributes']
-        except KeyError:
-            self.expired_attributes = set()
-            # 0.9 and earlier compat
-            for k in list(self.callables):
-                if self.callables[k] is self:
-                    self.expired_attributes.add(k)
-                    del self.callables[k]
+            try:
+                self.expired_attributes = state_dict['expired_attributes']
+            except KeyError:
+                self.expired_attributes = set()
+                # 0.9 and earlier compat
+                for k in list(self.callables):
+                    if self.callables[k] is self:
+                        self.expired_attributes.add(k)
+                        del self.callables[k]
+        else:
+            if 'expired_attributes' in state_dict:
+                self.expired_attributes = state_dict['expired_attributes']
+            else:
+                self.expired_attributes = set()
 
         self.__dict__.update([
             (k, state_dict[k]) for k in (
-                'key', 'load_options',
+                'key', 'load_options'
             ) if k in state_dict
         ])
+        if self.key:
+            try:
+                self.identity_token = self.key[2]
+            except IndexError:
+                # 1.1 and earlier compat before identity_token
+                assert len(self.key) == 2
+                self.key = self.key + (None, )
+                self.identity_token = None
 
         if 'load_path' in state_dict:
             self.load_path = PathRegistry.\
index 4267b79fb5661d70d2854e86bd07339de4f4504c..ad2a74a34cd508d60f6c79dd521c46c300f68386 100644 (file)
@@ -214,7 +214,7 @@ def identity_key(*args, **kwargs):
 
     This function has several call styles:
 
-    * ``identity_key(class, ident)``
+    * ``identity_key(class, ident, identity_token=token)``
 
       This form receives a mapped class and a primary key scalar or
       tuple as an argument.
@@ -222,10 +222,13 @@ def identity_key(*args, **kwargs):
       E.g.::
 
         >>> identity_key(MyClass, (1, 2))
-        (<class '__main__.MyClass'>, (1, 2))
+        (<class '__main__.MyClass'>, (1, 2), None)
 
       :param class: mapped class (must be a positional argument)
       :param ident: primary key, may be a scalar or tuple argument.
+      :param identity_token: optional identity token
+
+        .. versionadded:: 1.2 added identity_token
 
 
     * ``identity_key(instance=instance)``
@@ -239,7 +242,7 @@ def identity_key(*args, **kwargs):
 
         >>> instance = MyClass(1, 2)
         >>> identity_key(instance=instance)
-        (<class '__main__.MyClass'>, (1, 2))
+        (<class '__main__.MyClass'>, (1, 2), None)
 
       In this form, the given instance is ultimately run though
       :meth:`.Mapper.identity_key_from_instance`, which will have the
@@ -248,7 +251,7 @@ def identity_key(*args, **kwargs):
 
       :param instance: object instance (must be given as a keyword arg)
 
-    * ``identity_key(class, row=row)``
+    * ``identity_key(class, row=row, identity_token=token)``
 
       This form is similar to the class/tuple form, except is passed a
       database result row as a :class:`.RowProxy` object.
@@ -258,41 +261,50 @@ def identity_key(*args, **kwargs):
         >>> row = engine.execute("select * from table where a=1 and b=2").\
 first()
         >>> identity_key(MyClass, row=row)
-        (<class '__main__.MyClass'>, (1, 2))
+        (<class '__main__.MyClass'>, (1, 2), None)
 
       :param class: mapped class (must be a positional argument)
       :param row: :class:`.RowProxy` row returned by a :class:`.ResultProxy`
        (must be given as a keyword arg)
+      :param identity_token: optional identity token
+
+        .. versionadded:: 1.2 added identity_token
 
     """
     if args:
-        if len(args) == 1:
+        row = None
+        largs = len(args)
+        if largs == 1:
             class_ = args[0]
             try:
                 row = kwargs.pop("row")
             except KeyError:
                 ident = kwargs.pop("ident")
-        elif len(args) == 2:
-            class_, ident = args
-        elif len(args) == 3:
+        elif largs in (2, 3):
             class_, ident = args
         else:
             raise sa_exc.ArgumentError(
                 "expected up to three positional arguments, "
-                "got %s" % len(args))
+                "got %s" % largs)
+
+        identity_token = kwargs.pop("identity_token", None)
         if kwargs:
             raise sa_exc.ArgumentError("unknown keyword arguments: %s"
                                        % ", ".join(kwargs))
         mapper = class_mapper(class_)
-        if "ident" in locals():
-            return mapper.identity_key_from_primary_key(util.to_list(ident))
-        return mapper.identity_key_from_row(row)
-    instance = kwargs.pop("instance")
-    if kwargs:
-        raise sa_exc.ArgumentError("unknown keyword arguments: %s"
-                                   % ", ".join(kwargs.keys))
-    mapper = object_mapper(instance)
-    return mapper.identity_key_from_instance(instance)
+        if row is None:
+            return mapper.identity_key_from_primary_key(
+                util.to_list(ident), identity_token=identity_token)
+        else:
+            return mapper.identity_key_from_row(
+                row, identity_token=identity_token)
+    else:
+        instance = kwargs.pop("instance")
+        if kwargs:
+            raise sa_exc.ArgumentError("unknown keyword arguments: %s"
+                                       % ", ".join(kwargs.keys))
+        mapper = object_mapper(instance)
+        return mapper.identity_key_from_instance(instance)
 
 
 class ORMAdapter(sql_util.ColumnAdapter):
index 79487b2a7975d4287682db99438c41888d9dc4e2..c09753e5dfc2c0bf554092859136c9a6aba65ba3 100644 (file)
@@ -125,8 +125,10 @@ class ShardTest(object):
                 self.city = city
 
         class Report(object):
-            def __init__(self, temperature):
+            def __init__(self, temperature, id_=None):
                 self.temperature = temperature
+                if id_:
+                    self.id = id_
 
         mapper(WeatherLocation, weather_locations, properties={
             'reports': relationship(Report, backref='location'),
@@ -143,8 +145,8 @@ class ShardTest(object):
         dublin = WeatherLocation('Europe', 'Dublin')
         brasilia = WeatherLocation('South America', 'Brasila')
         quito = WeatherLocation('South America', 'Quito')
-        tokyo.reports.append(Report(80.0))
-        newyork.reports.append(Report(75))
+        tokyo.reports.append(Report(80.0, id_=1))
+        newyork.reports.append(Report(75, id_=1))
         quito.reports.append(Report(85))
         sess = create_session()
         for c in [
@@ -157,6 +159,13 @@ class ShardTest(object):
             quito,
         ]:
             sess.add(c)
+        sess.flush()
+
+        eq_(inspect(newyork).key[2], "north_america")
+        eq_(inspect(newyork).identity_token, "north_america")
+        eq_(inspect(dublin).key[2], "europe")
+        eq_(inspect(dublin).identity_token, "europe")
+
         sess.commit()
         sess.close()
         return sess
@@ -165,7 +174,7 @@ class ShardTest(object):
         sess = self._fixture_data()
         tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
         tokyo.city  # reload 'city' attribute on tokyo
-        sess.expunge_all()
+        sess.expire_all()
         eq_(db2.execute(weather_locations.select()).fetchall(), [(1,
             'Asia', 'Tokyo')])
         eq_(db1.execute(weather_locations.select()).fetchall(), [(2,
@@ -186,6 +195,33 @@ class ShardTest(object):
         eq_(set([c.city for c in asia_and_europe]), set(['Tokyo',
             'London', 'Dublin']))
 
+        # inspect the shard token stored with each instance
+        eq_(
+            set(inspect(c).key[2] for c in asia_and_europe),
+            set(['europe', 'asia']))
+
+        eq_(
+            set(inspect(c).identity_token for c in asia_and_europe),
+            set(['europe', 'asia']))
+
+        newyork = sess.query(WeatherLocation).filter_by(city="New York").one()
+        newyork_report = newyork.reports[0]
+        tokyo_report = tokyo.reports[0]
+
+        # same primary key, two identity keys
+        eq_(
+            inspect(newyork_report).identity_key,
+            (Report, (1, ), "north_america")
+        )
+        eq_(
+            inspect(tokyo_report).identity_key,
+            (Report, (1, ), "asia")
+        )
+
+        # the token representing the originating shard is available
+        eq_(inspect(newyork_report).identity_token, "north_america")
+        eq_(inspect(tokyo_report).identity_token, "asia")
+
     def test_get_baked_query(self):
         sess = self._fixture_data()
 
@@ -201,6 +237,8 @@ class ShardTest(object):
         t = bq(sess).get(tokyo.id)
         eq_(t.city, tokyo.city)
 
+        eq_(inspect(t).key[2], 'asia')
+
     def test_get_baked_query_shard_id(self):
         sess = self._fixture_data()
 
@@ -217,6 +255,8 @@ class ShardTest(object):
             lambda q: q.set_shard("asia")).get(tokyo.id)
         eq_(t.city, tokyo.city)
 
+        eq_(inspect(t).key[2], 'asia')
+
     def test_filter_baked_query_shard_id(self):
         sess = self._fixture_data()
 
index fc061676bee25ffea2a2066d2b36ebdca0acf041..a67ac4419c84113b3338cd7b868a251e6a7546af 100644 (file)
@@ -7,6 +7,7 @@ from test.orm import _fixtures
 from sqlalchemy.orm import class_mapper, synonym, Session, aliased
 from sqlalchemy.orm.attributes import instance_state, NO_VALUE
 from sqlalchemy import testing
+from sqlalchemy.orm.util import identity_key
 
 
 class TestORMInspection(_fixtures.FixtureTest):
@@ -487,7 +488,7 @@ class TestORMInspection(_fixtures.FixtureTest):
         insp = inspect(u1)
         eq_(
             insp.identity_key,
-            (User, (u1.id, ))
+            identity_key(User, (u1.id, ))
         )
 
     def test_persistence_states(self):
index 6b925806f57faf8604f27ed5f05d796bc90d1ddb..2369acb960b4bd7bb006de6cab816059fb87b883 100644 (file)
@@ -290,10 +290,44 @@ class PickleTest(fixtures.MappedTest):
         manager.class_ = User
         state_09['manager'] = manager
         state.__setstate__(state_09)
+        eq_(state.expired_attributes, {'name', 'id'})
 
         sess = Session()
         sess.add(inst)
         eq_(inst.name, 'ed')
+        # test identity_token expansion
+        eq_(sa.inspect(inst).key, (User, (1, ), None))
+
+    def test_11_pickle(self):
+        users = self.tables.users
+        mapper(User, users)
+        sess = Session()
+        u1 = User(id=1, name='ed')
+        sess.add(u1)
+        sess.commit()
+
+        sess.close()
+        # build the serialization manager by hand, bypassing __init__
+        manager = instrumentation._SerializeManager.__new__(
+            instrumentation._SerializeManager)
+        manager.class_ = User
+
+        state_11 = {
+            # note 'key' below is a two-element tuple with no identity token
+            'class_': User,
+            'modified': False,
+            'committed_state': {},
+            'instance': u1,
+            'manager': manager,
+            'key': (User, (1,)),
+            'expired_attributes': set(),
+            'expired': True}
+
+        state = sa_state.InstanceState.__new__(sa_state.InstanceState)
+        state.__setstate__(state_11)
+        # restored state should expose a None token and a 3-tuple key
+        eq_(state.identity_token, None)
+        eq_(state.identity_key, (User, (1,), None))
 
     @testing.requires.non_broken_pickle
     def test_options_with_descriptors(self):
index 1510689f9f2e05e7ae7758bc3f459523374fbf58..c639418d9d66f63872fee21b647fd24f0d469df8 100644 (file)
@@ -2,6 +2,7 @@ from __future__ import with_statement
 from sqlalchemy import (
     testing, exc as sa_exc, event, String, Column, Table, select, func)
 from sqlalchemy.sql import elements
+from sqlalchemy.orm.util import identity_key
 from sqlalchemy.testing import (
     fixtures, engines, eq_, assert_raises, assert_raises_message,
     assert_warnings, mock, expect_warnings, is_, is_not_)
@@ -12,6 +13,7 @@ from sqlalchemy.testing.util import gc_collect
 from test.orm._fixtures import FixtureTest
 from sqlalchemy import inspect
 
+
 class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
     run_inserts = None
     __backend__ = True
@@ -1715,8 +1717,8 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
         assert u1 in s
         assert u2 in s
 
-        assert s.identity_map[(User, ('u1',))] is u1
-        assert s.identity_map[(User, ('u2',))] is u2
+        assert s.identity_map[identity_key(User, ('u1',))] is u1
+        assert s.identity_map[identity_key(User, ('u2',))] is u2
 
     def test_multiple_key_replaced_by_update(self):
         users, User = self.tables.users, self.classes.User
@@ -1747,9 +1749,9 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
         assert u2 in s
         assert u3 in s
 
-        assert s.identity_map[(User, ('u1',))] is u1
-        assert s.identity_map[(User, ('u2',))] is u2
-        assert s.identity_map[(User, ('u3',))] is u3
+        assert s.identity_map[identity_key(User, ('u1',))] is u1
+        assert s.identity_map[identity_key(User, ('u2',))] is u2
+        assert s.identity_map[identity_key(User, ('u3',))] is u3
 
     def test_key_replaced_by_oob_insert(self):
         users, User = self.tables.users, self.classes.User
@@ -1774,4 +1776,4 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
         assert u1 in s
         assert u2 not in s
 
-        assert s.identity_map[(User, ('u1',))] is u1
+        assert s.identity_map[identity_key(User, ('u1',))] is u1
index b06884b867875bce19a44b12bcf4a5d2dbbe39ab..44161ddcdef27a0ea53deef0d885028f04d12c4a 100644 (file)
@@ -301,9 +301,9 @@ class IdentityKeyTest(_fixtures.FixtureTest):
         mapper(User, users)
 
         key = orm_util.identity_key(User, [1])
-        eq_(key, (User, (1,)))
+        eq_(key, (User, (1,), None))
         key = orm_util.identity_key(User, ident=[1])
-        eq_(key, (User, (1,)))
+        eq_(key, (User, (1,), None))
 
     def test_identity_key_scalar(self):
         User, users = self.classes.User, self.tables.users
@@ -311,9 +311,9 @@ class IdentityKeyTest(_fixtures.FixtureTest):
         mapper(User, users)
 
         key = orm_util.identity_key(User, 1)
-        eq_(key, (User, (1,)))
+        eq_(key, (User, (1,), None))
         key = orm_util.identity_key(User, ident=1)
-        eq_(key, (User, (1,)))
+        eq_(key, (User, (1,), None))
 
     def test_identity_key_2(self):
         users, User = self.tables.users, self.classes.User
@@ -324,7 +324,7 @@ class IdentityKeyTest(_fixtures.FixtureTest):
         s.add(u)
         s.flush()
         key = orm_util.identity_key(instance=u)
-        eq_(key, (User, (u.id,)))
+        eq_(key, (User, (u.id,), None))
 
     def test_identity_key_3(self):
         User, users = self.classes.User, self.tables.users
@@ -333,7 +333,17 @@ class IdentityKeyTest(_fixtures.FixtureTest):
 
         row = {users.c.id: 1, users.c.name: "Frank"}
         key = orm_util.identity_key(User, row=row)
-        eq_(key, (User, (1,)))
+        eq_(key, (User, (1,), None))
+
+    def test_identity_key_token(self):
+        User, users = self.classes.User, self.tables.users
+
+        mapper(User, users)
+        # positional and keyword ident should both carry the token in the key
+        key = orm_util.identity_key(User, [1], identity_token="token")
+        eq_(key, (User, (1,), "token"))
+        key = orm_util.identity_key(User, ident=[1], identity_token="token")
+        eq_(key, (User, (1,), "token"))
 
 
 class PathRegistryTest(_fixtures.FixtureTest):