]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use identity_token for refresh(), unexpire, undefer
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 May 2018 16:35:23 +0000 (12:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 May 2018 14:41:02 +0000 (10:41 -0400)
The horizontal sharding extension now makes use of the identity token
added to ORM identity keys as part of :ticket:`4137`, when an object
refresh or column-based deferred load or unexpiration operation occurs.
Since we know the "shard" that the object originated from, we make
use of this value when refreshing, thereby avoiding queries against
other shards that don't match this object's identity in any case.

Change-Id: Ib91637a65d94ace7405998b8410d62944a83f2eb
Fixes: #4247
(cherry picked from commit 4b71933489cae21ad94b71b0bc7271c075ad0dda)

doc/build/changelog/unreleased_12/4247.rst [new file with mode: 0644]
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/query.py
test/ext/test_horizontal_shard.py

diff --git a/doc/build/changelog/unreleased_12/4247.rst b/doc/build/changelog/unreleased_12/4247.rst
new file mode 100644 (file)
index 0000000..f1858b9
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+       :tags: bug, ext
+       :tickets: 4247
+
+       The horizontal sharding extension now makes use of the identity token
+       added to ORM identity keys as part of :ticket:`4137`, when an object
+       refresh or column-based deferred load or unexpiration operation occurs.
+       Since we know the "shard" that the object originated from, we make
+       use of this value when refreshing, thereby avoiding queries against
+       other shards that don't match this object's identity in any case.
\ No newline at end of file
index 266bd784ed5792b041e4eac4ef281fc838cf0b90..6516950edfa2601c90d83c653121c89fec571cc8 100644 (file)
@@ -51,7 +51,9 @@ class ShardedQuery(Query):
                 self._params)
             return self.instances(result, context)
 
-        if self._shard_id is not None:
+        if context.identity_token is not None:
+            return iter_for_shard(context.identity_token)
+        elif self._shard_id is not None:
             return iter_for_shard(self._shard_id)
         else:
             partial = []
index 1728b2d37813d04cdacdc8c09f8dc8aad6ebecf5..a169845d4a67d4bf4b2e3b731f6ef6452b0cb9d6 100644 (file)
@@ -177,19 +177,21 @@ def load_on_ident(query, key,
 
     if key is not None:
         ident = key[1]
+        identity_token = key[2]
     else:
-        ident = None
+        ident = identity_token = None
 
     return load_on_pk_identity(
         query, ident, refresh_state=refresh_state,
         with_for_update=with_for_update,
-        only_load_props=only_load_props
+        only_load_props=only_load_props,
+        identity_token=identity_token
     )
 
 
 def load_on_pk_identity(query, primary_key_identity,
                         refresh_state=None, with_for_update=None,
-                        only_load_props=None):
+                        only_load_props=None, identity_token=None):
 
     """Load the given primary key identity from the database."""
 
@@ -240,7 +242,8 @@ def load_on_pk_identity(query, primary_key_identity,
         populate_existing=bool(refresh_state),
         version_check=version_check,
         only_load_props=only_load_props,
-        refresh_state=refresh_state)
+        refresh_state=refresh_state,
+        identity_token=identity_token)
     q._order_by = None
 
     try:
index a17f590e9dc679bccbc78a84c1fd48203ea9371b..1c5a3bc6e38055c1b8a5ce1b7af7caaaa73e6fd3 100644 (file)
@@ -93,6 +93,7 @@ class Query(object):
     _autoflush = True
     _only_load_props = None
     _refresh_state = None
+    _refresh_identity_token = None
     _from_obj = ()
     _join_entities = ()
     _select_from_entity = None
@@ -439,7 +440,8 @@ class Query(object):
     def _get_options(self, populate_existing=None,
                      version_check=None,
                      only_load_props=None,
-                     refresh_state=None):
+                     refresh_state=None,
+                     identity_token=None):
         if populate_existing:
             self._populate_existing = populate_existing
         if version_check:
@@ -448,6 +450,8 @@ class Query(object):
             self._refresh_state = refresh_state
         if only_load_props:
             self._only_load_props = set(only_load_props)
+        if identity_token:
+            self._refresh_identity_token = identity_token
         return self
 
     def _clone(self):
@@ -4219,7 +4223,10 @@ class QueryContext(object):
         self.propagate_options = set(o for o in query._with_options if
                                      o.propagate_to_loaders)
         self.attributes = query._attributes.copy()
-        self.identity_token = None
+        if self.refresh_state is not None:
+            self.identity_token = query._refresh_identity_token
+        else:
+            self.identity_token = None
 
 
 class AliasOption(interfaces.MapperOption):
index 0bcacad378f621d5321d5ae0def77ab72bb570d9..4b37cbd16c9c5028c854631a66d9f29f5d12cd08 100644 (file)
@@ -373,6 +373,57 @@ class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
         result = session.query(Book).options(selectinload('pages')).all()
         eq_(result, [book])
 
+class RefreshDeferExpireTest(fixtures.DeclarativeMappedTest):
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class A(Base):
+            __tablename__ = 'a'
+            id = Column(Integer, primary_key=True)
+            data = Column(String(30))
+            deferred_data = deferred(Column(String(30)))
+
+    @classmethod
+    def insert_data(cls):
+        A = cls.classes.A
+        s = Session()
+        s.add(A(data='d1', deferred_data='d2'))
+        s.commit()
+
+    def _session_fixture(self):
+
+        return ShardedSession(
+            shards={
+                "main": testing.db,
+            },
+            shard_chooser=lambda *args: 'main',
+            id_chooser=lambda *args: ['fake', 'main'],
+            query_chooser=lambda *args: ['fake', 'main']
+        )
+
+    def test_refresh(self):
+        A = self.classes.A
+        session = self._session_fixture()
+        a1 = session.query(A).set_shard("main").first()
+
+        session.refresh(a1)
+
+    def test_deferred(self):
+        A = self.classes.A
+        session = self._session_fixture()
+        a1 = session.query(A).set_shard("main").first()
+
+        eq_(a1.deferred_data, "d2")
+
+    def test_unexpire(self):
+        A = self.classes.A
+        session = self._session_fixture()
+        a1 = session.query(A).set_shard("main").first()
+
+        session.expire(a1)
+        eq_(a1.data, "d1")
+
 
 class LazyLoadFromIdentityMapTest(fixtures.DeclarativeMappedTest):
     @classmethod