From: Ben Darnell Date: Mon, 2 Jan 2012 00:17:35 +0000 (-0800) Subject: Allow exceptions thrown in the first (synchronous) phase of a gen.Task X-Git-Tag: v2.2.0~55 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=008e605ec6431e696158ed65adb89eb85a65aa7f;p=thirdparty%2Ftornado.git Allow exceptions thrown in the first (synchronous) phase of a gen.Task to be caught by the generator. --- diff --git a/tornado/gen.py b/tornado/gen.py index 6f6b49aa5..59684f50f 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -255,6 +255,7 @@ class Runner(object): self.running = False self.finished = False self.exc_info = None + self.had_exception = False def register_callback(self, key): """Adds ``key`` to the list of callbacks.""" @@ -296,6 +297,7 @@ class Runner(object): self.exc_info = sys.exc_info() try: if self.exc_info is not None: + self.had_exception = True exc_info = self.exc_info self.exc_info = None yielded = self.gen.throw(*exc_info) @@ -303,7 +305,11 @@ class Runner(object): yielded = self.gen.send(next) except StopIteration: self.finished = True - if self.pending_callbacks: + if self.pending_callbacks and not self.had_exception: + # If we ran cleanly without waiting on all callbacks + # raise an error (really more of a warning). If we + # had an exception then some callbacks may have been + # orphaned, so skip the check in that case. raise LeakedCallbackError( "finished without waiting for callbacks %r" % self.pending_callbacks) @@ -315,7 +321,10 @@ class Runner(object): yielded = Multi(yielded) if isinstance(yielded, YieldPoint): self.yield_point = yielded - self.yield_point.start(self) + try: + self.yield_point.start(self) + except Exception: + self.exc_info = sys.exc_info() else: self.exc_info = (BadYieldError("yielded unknown object %r" % yielded),) finally: diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index 4c6a4d591..fec486705 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -57,6 +57,19 @@ class GenTest(AsyncTestCase): 1/0 self.assertRaises(ZeroDivisionError, self.run_gen, f) + def test_exception_in_task_phase1(self): + def fail_task(callback): + 1/0 + + @gen.engine + def f(): + try: + yield gen.Task(fail_task) + raise Exception("did not get expected exception") + except ZeroDivisionError: + self.stop() + self.run_gen(f) + def test_with_arg(self): @gen.engine def f():