]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use identity_token for refresh(), unexpire, undefer
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 May 2018 16:35:23 +0000 (12:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 May 2018 16:35:23 +0000 (12:35 -0400)
The horizontal sharding extension now makes use of the identity token
added to ORM identity keys as part of :ticket:`4137`, when an object
refresh or column-based deferred load or unexpiration operation occurs.
Since we know the "shard" that the object originated from, we make
use of this value when refreshing, thereby avoiding queries against
other shards that don't match this object's identity in any case.

Change-Id: Ib91637a65d94ace7405998b8410d62944a83f2eb
Fixes: #4247
doc/build/changelog/unreleased_12/4247.rst [new file with mode: 0644]
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/query.py
test/ext/test_horizontal_shard.py

diff --git a/doc/build/changelog/unreleased_12/4247.rst b/doc/build/changelog/unreleased_12/4247.rst
new file mode 100644 (file)
index 0000000..f1858b9
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+       :tags: bug, ext
+       :tickets: 4247
+
+       The horizontal sharding extension now makes use of the identity token
+       added to ORM identity keys as part of :ticket:`4137`, when an object
+       refresh or column-based deferred load or unexpiration operation occurs.
+       Since we know the "shard" that the object originated from, we make
+       use of this value when refreshing, thereby avoiding queries against
+       other shards that don't match this object's identity in any case.
\ No newline at end of file
index 266bd784ed5792b041e4eac4ef281fc838cf0b90..6516950edfa2601c90d83c653121c89fec571cc8 100644 (file)
@@ -51,7 +51,9 @@ class ShardedQuery(Query):
                 self._params)
             return self.instances(result, context)
 
-        if self._shard_id is not None:
+        if context.identity_token is not None:
+            return iter_for_shard(context.identity_token)
+        elif self._shard_id is not None:
             return iter_for_shard(self._shard_id)
         else:
             partial = []
index 1728b2d37813d04cdacdc8c09f8dc8aad6ebecf5..a169845d4a67d4bf4b2e3b731f6ef6452b0cb9d6 100644 (file)
@@ -177,19 +177,21 @@ def load_on_ident(query, key,
 
     if key is not None:
         ident = key[1]
+        identity_token = key[2]
     else:
-        ident = None
+        ident = identity_token = None
 
     return load_on_pk_identity(
         query, ident, refresh_state=refresh_state,
         with_for_update=with_for_update,
-        only_load_props=only_load_props
+        only_load_props=only_load_props,
+        identity_token=identity_token
     )
 
 
 def load_on_pk_identity(query, primary_key_identity,
                         refresh_state=None, with_for_update=None,
-                        only_load_props=None):
+                        only_load_props=None, identity_token=None):
 
     """Load the given primary key identity from the database."""
 
@@ -240,7 +242,8 @@ def load_on_pk_identity(query, primary_key_identity,
         populate_existing=bool(refresh_state),
         version_check=version_check,
         only_load_props=only_load_props,
-        refresh_state=refresh_state)
+        refresh_state=refresh_state,
+        identity_token=identity_token)
     q._order_by = None
 
     try:
index ea8371f506112d6786ab922cfd0d7144d4990591..a5f3d01f66683e7bfba1e6cb28b6716ea86f6561 100644 (file)
@@ -93,6 +93,7 @@ class Query(object):
     _autoflush = True
     _only_load_props = None
     _refresh_state = None
+    _refresh_identity_token = None
     _from_obj = ()
     _join_entities = ()
     _select_from_entity = None
@@ -439,7 +440,8 @@ class Query(object):
     def _get_options(self, populate_existing=None,
                      version_check=None,
                      only_load_props=None,
-                     refresh_state=None):
+                     refresh_state=None,
+                     identity_token=None):
         if populate_existing:
             self._populate_existing = populate_existing
         if version_check:
@@ -448,6 +450,8 @@ class Query(object):
             self._refresh_state = refresh_state
         if only_load_props:
             self._only_load_props = set(only_load_props)
+        if identity_token:
+            self._refresh_identity_token = identity_token
         return self
 
     def _clone(self):
@@ -4228,7 +4232,10 @@ class QueryContext(object):
         self.propagate_options = set(o for o in query._with_options if
                                      o.propagate_to_loaders)
         self.attributes = query._attributes.copy()
-        self.identity_token = None
+        if self.refresh_state is not None:
+            self.identity_token = query._refresh_identity_token
+        else:
+            self.identity_token = None
 
 
 class AliasOption(interfaces.MapperOption):
index 0bcacad378f621d5321d5ae0def77ab72bb570d9..4b37cbd16c9c5028c854631a66d9f29f5d12cd08 100644 (file)
@@ -373,6 +373,57 @@ class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
         result = session.query(Book).options(selectinload('pages')).all()
         eq_(result, [book])
 
+class RefreshDeferExpireTest(fixtures.DeclarativeMappedTest):
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class A(Base):
+            __tablename__ = 'a'
+            id = Column(Integer, primary_key=True)
+            data = Column(String(30))
+            deferred_data = deferred(Column(String(30)))
+
+    @classmethod
+    def insert_data(cls):
+        A = cls.classes.A
+        s = Session()
+        s.add(A(data='d1', deferred_data='d2'))
+        s.commit()
+
+    def _session_fixture(self):
+
+        return ShardedSession(
+            shards={
+                "main": testing.db,
+            },
+            shard_chooser=lambda *args: 'main',
+            id_chooser=lambda *args: ['fake', 'main'],
+            query_chooser=lambda *args: ['fake', 'main']
+        )
+
+    def test_refresh(self):
+        A = self.classes.A
+        session = self._session_fixture()
+        a1 = session.query(A).set_shard("main").first()
+
+        session.refresh(a1)
+
+    def test_deferred(self):
+        A = self.classes.A
+        session = self._session_fixture()
+        a1 = session.query(A).set_shard("main").first()
+
+        eq_(a1.deferred_data, "d2")
+
+    def test_unexpire(self):
+        A = self.classes.A
+        session = self._session_fixture()
+        a1 = session.query(A).set_shard("main").first()
+
+        session.expire(a1)
+        eq_(a1.data, "d1")
+
 
 class LazyLoadFromIdentityMapTest(fixtures.DeclarativeMappedTest):
     @classmethod