]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added hooks for alternate session classes into sessionmaker
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 3 Aug 2007 19:54:16 +0000 (19:54 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 3 Aug 2007 19:54:16 +0000 (19:54 +0000)
- moved shard example/unittest over to sessionmaker

examples/sharding/attribute_shard.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/shard.py
test/orm/sharding/shard.py

index e95b978aec14bfb85db4766c3048b94401595541..6e4732989646229b4c331cb4f98e6b36a2b325f7 100644 (file)
@@ -34,13 +34,15 @@ db4 = create_engine('sqlite:///shard4.db', echo=echo)
 
 # step 3. create session function.  this binds the shard ids
 # to databases within a ShardedSession and returns it.
-def create_session():
-    s = ShardedSession(shard_chooser, id_chooser, query_chooser)
-    s.bind_shard('north_america', db1)
-    s.bind_shard('asia', db2)
-    s.bind_shard('europe', db3)
-    s.bind_shard('south_america', db4)
-    return s
+create_session = sessionmaker(class_=ShardedSession)
+
+create_session.configure(shards={
+    'north_america':db1,
+    'asia':db2,
+    'europe':db3,
+    'south_america':db4
+})
+
 
 # step 4.  table setup.
 meta = MetaData()
@@ -143,6 +145,9 @@ def query_chooser(query):
     else:
         return ids
 
+# further configure create_session to use these functions
+create_session.configure(shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser)
+
 # step 6.  mapped classes.    
 class WeatherLocation(object):
     def __init__(self, continent, city):
index f982da5368c318a6892809bde8c1833f3c96328d..80c1a5b0d1d753141fb49f5d2ecde6032403db41 100644 (file)
@@ -14,14 +14,17 @@ from sqlalchemy.orm.mapper import global_extensions
 
 __all__ = ['Session', 'SessionTransaction']
 
-def sessionmaker(autoflush=True, transactional=True, bind=None, **kwargs):
+def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **kwargs):
     """Generate a Session configuration."""
     
     kwargs['bind'] = bind
     kwargs['autoflush'] = autoflush
     kwargs['transactional'] = transactional
 
-    class Sess(Session):
+    if class_ is None:
+        class_ = Session
+        
+    class Sess(class_):
         def __init__(self, **local_kwargs):
             for k in kwargs:
                 local_kwargs.setdefault(k, kwargs[k])
index 26d03372f319ffaf474bc3830ce41d941e5f785e..9d4396d2bb43fabaaa0ce1902d404c265d00a755 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy.orm import Query
 __all__ = ['ShardedSession', 'ShardedQuery']
 
 class ShardedSession(Session):
-    def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs):
+    def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
         """construct a ShardedSession.
         
             shard_chooser
@@ -32,6 +32,9 @@ class ShardedSession(Session):
         self.__binds = {}
         self._mapper_flush_opts = {'connection_callable':self.connection}
         self._query_cls = ShardedQuery
+        if shards is not None:
+            for k in shards:
+                self.bind_shard(k, shards[k])
         
     def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
         if shard_id is None:
index faa980cc27f633fe78fc91cd2d8edb6bf276a6f6..c1dd63d6575e72aa367d74baa8c9a41ad9e16818 100644 (file)
@@ -90,14 +90,16 @@ class ShardTest(PersistTest):
                 return ['north_america', 'asia', 'europe', 'south_america']
             else:
                 return ids
-
-        def create_session():
-            s = ShardedSession(shard_chooser, id_chooser, query_chooser)
-            s.bind_shard('north_america', db1)
-            s.bind_shard('asia', db2)
-            s.bind_shard('europe', db3)
-            s.bind_shard('south_america', db4)
-            return s
+        
+        create_session = sessionmaker(class_=ShardedSession, autoflush=True, transactional=True)
+
+        create_session.configure(shards={
+            'north_america':db1,
+            'asia':db2,
+            'europe':db3,
+            'south_america':db4
+        }, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser)
+        
 
     def setup_mappers(self):
         global WeatherLocation, Report
@@ -133,10 +135,13 @@ class ShardTest(PersistTest):
         sess = create_session()
         for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
             sess.save(c)
-        sess.flush()
+        sess.commit()
 
         sess.clear()
 
+        assert db2.execute(weather_locations.select()).fetchall() == [(1, 'Asia', 'Tokyo')]
+        assert db1.execute(weather_locations.select()).fetchall() == [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')]
+        
         t = sess.query(WeatherLocation).get(tokyo.id)
         assert t.city == tokyo.city
         assert t.reports[0].temperature == 80.0