From: Mike Bayer Date: Sun, 28 Dec 2008 19:54:58 +0000 (+0000) Subject: - Fixed shard_id argument on ShardedSession.execute(). X-Git-Tag: rel_0_5_0~59 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=d245397bd2ba562993f1c9b5880f1c818050585c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Fixed shard_id argument on ShardedSession.execute(). [ticket:1072] --- diff --git a/CHANGES b/CHANGES index 0e648b0494..fd3dfd0fa3 100644 --- a/CHANGES +++ b/CHANGES @@ -144,6 +144,9 @@ CHANGES queried against. The column won't be "pulled in" from a subclass or superclass mapper since it's not needed. + + - Fixed shard_id argument on ShardedSession.execute(). + [ticket:1072] - sql - Columns can again contain percent signs within their diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 5a0c3faff9..876471c132 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1101,7 +1101,7 @@ class Query(object): return self._execute_and_instances(context) def _execute_and_instances(self, querycontext): - result = self.session.execute(querycontext.statement, params=self._params, mapper=self._mapper_zero_or_none(), _state=self._refresh_state) + result = self.session.execute(querycontext.statement, params=self._params, mapper=self._mapper_zero_or_none()) return self.instances(result, querycontext) def instances(self, cursor, __context=None): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 3bc3fb4fc2..3be8dac392 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -690,7 +690,7 @@ class Session(object): self.transaction.prepare() - def connection(self, mapper=None, clause=None, _state=None): + def connection(self, mapper=None, clause=None): """Return the active Connection. Retrieves the ``Connection`` managing the current transaction. Any @@ -712,7 +712,7 @@ class Session(object): Optional, any ``ClauseElement`` """ - return self.__connection(self.get_bind(mapper, clause, _state)) + return self.__connection(self.get_bind(mapper, clause)) def __connection(self, engine, **kwargs): if self.transaction is not None: @@ -720,7 +720,7 @@ class Session(object): else: return engine.contextual_connect(**kwargs) - def execute(self, clause, params=None, mapper=None, _state=None): + def execute(self, clause, params=None, mapper=None, **kw): """Execute a clause within the current transaction. Returns a ``ResultProxy`` of execution results. `autocommit` Sessions @@ -741,21 +741,23 @@ class Session(object): mapper Optional, a ``mapper`` or mapped class - _state - Optional, an instance of a mapped class - + \**kw + Additional keyword arguments are sent to :method:`get_bind()` + which locates a connectable to use for the execution. + Subclasses of :class:`Session` may override this. + """ clause = expression._literal_as_text(clause) - engine = self.get_bind(mapper, clause=clause, _state=_state) + engine = self.get_bind(mapper, clause=clause, **kw) return self.__connection(engine, close_with_result=True).execute( clause, params or {}) - def scalar(self, clause, params=None, mapper=None, _state=None): + def scalar(self, clause, params=None, mapper=None): """Like execute() but return a scalar result.""" - engine = self.get_bind(mapper, clause=clause, _state=_state) + engine = self.get_bind(mapper, clause=clause) return self.__connection(engine, close_with_result=True).scalar( clause, params or {}) @@ -838,7 +840,7 @@ class Session(object): """ self.__binds[table] = bind - def get_bind(self, mapper, clause=None, _state=None): + def get_bind(self, mapper, clause=None): """Return an engine corresponding to the given arguments. All arguments are optional. @@ -849,11 +851,8 @@ class Session(object): clause Optional, A ClauseElement (i.e. select(), text(), etc.) - _state - Optional, SA internal representation of a mapped instance - """ - if mapper is clause is _state is None: + if mapper is clause is None: if self.bind: return self.bind else: @@ -862,16 +861,10 @@ class Session(object): "Connection, and no context was provided to locate " "a binding.") - s_mapper = _state is not None and _state_mapper(_state) or None c_mapper = mapper is not None and _class_to_mapper(mapper) or None # manually bound? if self.__binds: - if s_mapper: - if s_mapper.base_mapper in self.__binds: - return self.__binds[s_mapper.base_mapper] - elif s_mapper.mapped_table in self.__binds: - return self.__binds[s_mapper.mapped_table] if c_mapper: if c_mapper.base_mapper in self.__binds: return self.__binds[c_mapper.base_mapper] @@ -888,8 +881,6 @@ class Session(object): if isinstance(clause, sql.expression.ClauseElement) and clause.bind: return clause.bind - if s_mapper and s_mapper.mapped_table.bind: - return s_mapper.mapped_table.bind if c_mapper and c_mapper.mapped_table.bind: return c_mapper.mapped_table.bind @@ -898,8 +889,6 @@ class Session(object): context.append('mapper %s' % c_mapper) if clause is not None: context.append('SQL expression') - if _state is not None: - context.append('state %r' % _state) raise sa_exc.UnboundExecutionError( "Could not locate a bind configured on %s or this Session" % ( diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py index f769b206f4..b59d284c2b 100644 --- a/lib/sqlalchemy/orm/shard.py +++ b/lib/sqlalchemy/orm/shard.py @@ -65,7 +65,7 @@ class ShardedSession(Session): else: return self.get_bind(mapper, shard_id=shard_id, instance=instance).contextual_connect(**kwargs) - def get_bind(self, mapper, shard_id=None, instance=None, clause=None): + def get_bind(self, mapper, shard_id=None, instance=None, clause=None, **kw): if shard_id is None: shard_id = self.shard_chooser(mapper, instance, clause=clause) return self.__binds[shard_id] diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/shard.py index f25d097fd7..63f6932ea8 100644 --- a/test/orm/sharding/shard.py +++ b/test/orm/sharding/shard.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import * from sqlalchemy.orm.shard import ShardedSession from sqlalchemy.sql import operators from testlib import * +from testlib.testing import eq_ # TODO: ShardTest can be turned into a base for further subclasses @@ -142,18 +143,19 @@ class ShardTest(TestBase): tokyo.city # reload 'city' attribute on tokyo sess.clear() - assert db2.execute(weather_locations.select()).fetchall() == [(1, 'Asia', 'Tokyo')] - assert db1.execute(weather_locations.select()).fetchall() == [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')] - + eq_(db2.execute(weather_locations.select()).fetchall(), [(1, 'Asia', 'Tokyo')]) + eq_(db1.execute(weather_locations.select()).fetchall(), [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')]) + eq_(sess.execute(weather_locations.select(), shard_id='asia').fetchall(), [(1, 'Asia', 'Tokyo')]) + t = sess.query(WeatherLocation).get(tokyo.id) - assert t.city == tokyo.city - assert t.reports[0].temperature == 80.0 + eq_(t.city, tokyo.city) + eq_(t.reports[0].temperature, 80.0) north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America') - assert set([c.city for c in north_american_cities]) == set(['New York', 'Toronto']) + eq_(set([c.city for c in north_american_cities]), set(['New York', 'Toronto'])) asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_(['Europe', 'Asia'])) - assert set([c.city for c in asia_and_europe]) == set(['Tokyo', 'London', 'Dublin']) + eq_(set([c.city for c in asia_and_europe]), set(['Tokyo', 'London', 'Dublin']))