]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Horizontal shard query places 'shard_id' in
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Jan 2011 22:03:19 +0000 (17:03 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Jan 2011 22:03:19 +0000 (17:03 -0500)
context.attributes where it's accessible by the
"load()" event. [ticket:2031]

CHANGES
lib/sqlalchemy/ext/horizontal_shard.py
test/ext/test_horizontal_shard.py

diff --git a/CHANGES b/CHANGES
index c9dee94b21b14fddccd4f7ce6ae2df1f83a02ac8..0593ba3743f34db8957f650a6afd2c26cb7a6b05 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -87,6 +87,10 @@ CHANGES
 
   - ScopedSession.mapper is removed (deprecated since 0.5).
 
+  - Horizontal shard query places 'shard_id' in 
+    context.attributes where it's accessible by the 
+    "load()" event. [ticket:2031]
+
 - sql
   - LIMIT/OFFSET clauses now use bind parameters
     [ticket:805]
index 41fae8e7bcb1d9ae2de24b24159e4fa7d41b0826..926c029d0e5e568f0818240a852ba5aeb9dcf139 100644 (file)
@@ -95,6 +95,7 @@ class ShardedQuery(Query):
 
     def _execute_and_instances(self, context):
         if self._shard_id is not None:
+            context.attributes['shard_id'] = self._shard_id
             result = self.session.connection(
                             mapper=self._mapper_zero(),
                             shard_id=self._shard_id).execute(context.statement, self._params)
@@ -102,6 +103,7 @@ class ShardedQuery(Query):
         else:
             partial = []
             for shard_id in self.query_chooser(self):
+                context.attributes['shard_id'] = shard_id
                 result = self.session.connection(
                             mapper=self._mapper_zero(),
                             shard_id=shard_id).execute(context.statement, self._params)
index f2b56bca899f1bd4a16d0fc2341ed6d83736b6e6..45903a0e6a796858beb45cfa70f9d91e595147f0 100644 (file)
@@ -1,5 +1,6 @@
 import datetime, os
 from sqlalchemy import *
+from sqlalchemy import event
 from sqlalchemy import sql
 from sqlalchemy.orm import *
 from sqlalchemy.ext.horizontal_shard import ShardedSession
@@ -11,8 +12,7 @@ from nose import SkipTest
 # TODO: ShardTest can be turned into a base for further subclasses
 
 class ShardTest(TestBase):
-    @classmethod
-    def setup_class(cls):
+    def setUp(self):
         global db1, db2, db3, db4, weather_locations, weather_reports
 
         try:
@@ -57,11 +57,12 @@ class ShardTest(TestBase):
 
         db1.execute(ids.insert(), nextid=1)
 
-        cls.setup_session()
-        cls.setup_mappers()
+        self.setup_session()
+        self.setup_mappers()
+
+    def tearDown(self):
+        clear_mappers()
 
-    @classmethod
-    def teardown_class(cls):
         for db in (db1, db2, db3, db4):
             db.connect().invalidate()
         for i in range(1,5):
@@ -101,8 +102,8 @@ class ShardTest(TestBase):
                             for bind in binary.right.clauses:
                                 ids.append(shard_lookup[bind.value])
 
-
-            FindContinent().traverse(query._criterion)
+            if query._criterion is not None:
+                FindContinent().traverse(query._criterion)
             if len(ids) == 0:
                 return ['north_america', 'asia', 'europe',
                         'south_america']
@@ -139,8 +140,7 @@ class ShardTest(TestBase):
         })
 
         mapper(Report, weather_reports)
-
-    def test_roundtrip(self):
+    def _fixture_data(self):
         tokyo = WeatherLocation('Asia', 'Tokyo')
         newyork = WeatherLocation('North America', 'New York')
         toronto = WeatherLocation('North America', 'Toronto')
@@ -163,6 +163,11 @@ class ShardTest(TestBase):
             ]:
             sess.add(c)
         sess.commit()
+        return sess
+
+    def test_roundtrip(self):
+        sess = self._fixture_data()
+        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
         tokyo.city  # reload 'city' attribute on tokyo
         sess.expunge_all()
         eq_(db2.execute(weather_locations.select()).fetchall(), [(1,
@@ -186,3 +191,20 @@ class ShardTest(TestBase):
         eq_(set([c.city for c in asia_and_europe]), set(['Tokyo',
             'London', 'Dublin']))
 
+    def test_shard_id_event(self):
+        canary = []
+        def load(instance, ctx):
+            canary.append(ctx.attributes["shard_id"])
+
+        event.listen(WeatherLocation, "load", load)
+        sess = self._fixture_data()
+
+        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").set_shard("asia").one()
+
+        sess.query(WeatherLocation).all()
+        eq_(
+            canary, 
+            ['asia', 'north_america', 'north_america', 
+            'europe', 'europe', 'south_america', 
+            'south_america']
+        )
\ No newline at end of file