]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed shard_id argument on ShardedSession.execute().
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Dec 2008 19:54:58 +0000 (19:54 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Dec 2008 19:54:58 +0000 (19:54 +0000)
[ticket:1072]

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/shard.py
test/orm/sharding/shard.py

diff --git a/CHANGES b/CHANGES
index 0e648b0494017600226e2a7ba139e450823954e1..fd3dfd0fa31a5cadfcdf83a0a230f341d42783db 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -144,6 +144,9 @@ CHANGES
           queried against. The column won't be
           "pulled in" from a subclass or superclass
           mapper since it's not needed.
+
+    - Fixed shard_id argument on ShardedSession.execute().
+      [ticket:1072]
       
 - sql
     - Columns can again contain percent signs within their
index 5a0c3faff96ee30f44139de2a7ccb9dc96c36940..876471c1326e3f29476b197e2b167259ff7dafe3 100644 (file)
@@ -1101,7 +1101,7 @@ class Query(object):
         return self._execute_and_instances(context)
 
     def _execute_and_instances(self, querycontext):
-        result = self.session.execute(querycontext.statement, params=self._params, mapper=self._mapper_zero_or_none(), _state=self._refresh_state)
+        result = self.session.execute(querycontext.statement, params=self._params, mapper=self._mapper_zero_or_none())
         return self.instances(result, querycontext)
 
     def instances(self, cursor, __context=None):
index 3bc3fb4fc2838e488b7c76ce8e16c3ebded12f50..3be8dac3928b8cb0286413f36cb095859551d85f 100644 (file)
@@ -690,7 +690,7 @@ class Session(object):
 
         self.transaction.prepare()
 
-    def connection(self, mapper=None, clause=None, _state=None):
+    def connection(self, mapper=None, clause=None):
         """Return the active Connection.
 
         Retrieves the ``Connection`` managing the current transaction.  Any
@@ -712,7 +712,7 @@ class Session(object):
           Optional, any ``ClauseElement``
 
         """
-        return self.__connection(self.get_bind(mapper, clause, _state))
+        return self.__connection(self.get_bind(mapper, clause))
 
     def __connection(self, engine, **kwargs):
         if self.transaction is not None:
@@ -720,7 +720,7 @@ class Session(object):
         else:
             return engine.contextual_connect(**kwargs)
 
-    def execute(self, clause, params=None, mapper=None, _state=None):
+    def execute(self, clause, params=None, mapper=None, **kw):
         """Execute a clause within the current transaction.
 
         Returns a ``ResultProxy`` of execution results.  `autocommit` Sessions
@@ -741,21 +741,23 @@ class Session(object):
         mapper
           Optional, a ``mapper`` or mapped class
 
-        _state
-          Optional, an instance of a mapped class
-
+        \**kw
+          Additional keyword arguments are sent to :method:`get_bind()`
+          which locates a connectable to use for the execution.
+          Subclasses of :class:`Session` may override this.
+          
         """
         clause = expression._literal_as_text(clause)
 
-        engine = self.get_bind(mapper, clause=clause, _state=_state)
+        engine = self.get_bind(mapper, clause=clause, **kw)
 
         return self.__connection(engine, close_with_result=True).execute(
             clause, params or {})
 
-    def scalar(self, clause, params=None, mapper=None, _state=None):
+    def scalar(self, clause, params=None, mapper=None):
         """Like execute() but return a scalar result."""
 
-        engine = self.get_bind(mapper, clause=clause, _state=_state)
+        engine = self.get_bind(mapper, clause=clause)
 
         return self.__connection(engine, close_with_result=True).scalar(
             clause, params or {})
@@ -838,7 +840,7 @@ class Session(object):
         """
         self.__binds[table] = bind
 
-    def get_bind(self, mapper, clause=None, _state=None):
+    def get_bind(self, mapper, clause=None):
         """Return an engine corresponding to the given arguments.
 
         All arguments are optional.
@@ -849,11 +851,8 @@ class Session(object):
         clause
           Optional, A ClauseElement (i.e. select(), text(), etc.)
 
-        _state
-          Optional, SA internal representation of a mapped instance
-
         """
-        if mapper is clause is _state is None:
+        if mapper is clause is None:
             if self.bind:
                 return self.bind
             else:
@@ -862,16 +861,10 @@ class Session(object):
                     "Connection, and no context was provided to locate "
                     "a binding.")
 
-        s_mapper = _state is not None and _state_mapper(_state) or None
         c_mapper = mapper is not None and _class_to_mapper(mapper) or None
 
         # manually bound?
         if self.__binds:
-            if s_mapper:
-                if s_mapper.base_mapper in self.__binds:
-                    return self.__binds[s_mapper.base_mapper]
-                elif s_mapper.mapped_table in self.__binds:
-                    return self.__binds[s_mapper.mapped_table]
             if c_mapper:
                 if c_mapper.base_mapper in self.__binds:
                     return self.__binds[c_mapper.base_mapper]
@@ -888,8 +881,6 @@ class Session(object):
         if isinstance(clause, sql.expression.ClauseElement) and clause.bind:
             return clause.bind
 
-        if s_mapper and s_mapper.mapped_table.bind:
-            return s_mapper.mapped_table.bind
         if c_mapper and c_mapper.mapped_table.bind:
             return c_mapper.mapped_table.bind
 
@@ -898,8 +889,6 @@ class Session(object):
             context.append('mapper %s' % c_mapper)
         if clause is not None:
             context.append('SQL expression')
-        if _state is not None:
-            context.append('state %r' % _state)
 
         raise sa_exc.UnboundExecutionError(
             "Could not locate a bind configured on %s or this Session" % (
index f769b206f4d77dd4eb270aea6ab876f6524a7046..b59d284c2b67cd04128c173b21963d4577835ff2 100644 (file)
@@ -65,7 +65,7 @@ class ShardedSession(Session):
         else:
             return self.get_bind(mapper, shard_id=shard_id, instance=instance).contextual_connect(**kwargs)
     
-    def get_bind(self, mapper, shard_id=None, instance=None, clause=None):
+    def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw):
         if shard_id is None:
             shard_id = self.shard_chooser(mapper, instance, clause=clause)
         return self.__binds[shard_id]
index f25d097fd7624810f1a24645a7283dbccf617d25..63f6932ea83c2cb4fcd73d49bb94d9bc740c3887 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy.orm import *
 from sqlalchemy.orm.shard import ShardedSession
 from sqlalchemy.sql import operators
 from testlib import *
+from testlib.testing import eq_
 
 # TODO: ShardTest can be turned into a base for further subclasses
 
@@ -142,18 +143,19 @@ class ShardTest(TestBase):
         tokyo.city   # reload 'city' attribute on tokyo
         sess.clear()
 
-        assert db2.execute(weather_locations.select()).fetchall() == [(1, 'Asia', 'Tokyo')]
-        assert db1.execute(weather_locations.select()).fetchall() == [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')]
-
+        eq_(db2.execute(weather_locations.select()).fetchall(), [(1, 'Asia', 'Tokyo')])
+        eq_(db1.execute(weather_locations.select()).fetchall(), [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')])
+        eq_(sess.execute(weather_locations.select(), shard_id='asia').fetchall(), [(1, 'Asia', 'Tokyo')])
+        
         t = sess.query(WeatherLocation).get(tokyo.id)
-        assert t.city == tokyo.city
-        assert t.reports[0].temperature == 80.0
+        eq_(t.city, tokyo.city)
+        eq_(t.reports[0].temperature, 80.0)
 
         north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America')
-        assert set([c.city for c in north_american_cities]) == set(['New York', 'Toronto'])
+        eq_(set([c.city for c in north_american_cities]), set(['New York', 'Toronto']))
 
         asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_(['Europe', 'Asia']))
-        assert set([c.city for c in asia_and_europe]) == set(['Tokyo', 'London', 'Dublin'])
+        eq_(set([c.city for c in asia_and_europe]), set(['Tokyo', 'London', 'Dublin']))