From 008e605ec6431e696158ed65adb89eb85a65aa7f Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sun, 1 Jan 2012 16:17:35 -0800 Subject: [PATCH] Allow exceptions thrown in the first (synchronous) phase of a gen.Task to be caught by the generator. --- tornado/gen.py | 13 +++++++++++-- tornado/test/gen_test.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/tornado/gen.py b/tornado/gen.py index 6f6b49aa5..59684f50f 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -255,6 +255,7 @@ class Runner(object): self.running = False self.finished = False self.exc_info = None + self.had_exception = False def register_callback(self, key): """Adds ``key`` to the list of callbacks.""" @@ -296,6 +297,7 @@ class Runner(object): self.exc_info = sys.exc_info() try: if self.exc_info is not None: + self.had_exception = True exc_info = self.exc_info self.exc_info = None yielded = self.gen.throw(*exc_info) @@ -303,7 +305,11 @@ class Runner(object): yielded = self.gen.send(next) except StopIteration: self.finished = True - if self.pending_callbacks: + if self.pending_callbacks and not self.had_exception: + # If we ran cleanly without waiting on all callbacks + # raise an error (really more of a warning). If we + # had an exception then some callbacks may have been + # orphaned, so skip the check in that case. raise LeakedCallbackError( "finished without waiting for callbacks %r" % self.pending_callbacks) @@ -315,7 +321,10 @@ class Runner(object): yielded = Multi(yielded) if isinstance(yielded, YieldPoint): self.yield_point = yielded - self.yield_point.start(self) + try: + self.yield_point.start(self) + except Exception: + self.exc_info = sys.exc_info() else: self.exc_info = (BadYieldError("yielded unknown object %r" % yielded),) finally: diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index 4c6a4d591..fec486705 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -57,6 +57,19 @@ class GenTest(AsyncTestCase): 1/0 self.assertRaises(ZeroDivisionError, self.run_gen, f) + def test_exception_in_task_phase1(self): + def fail_task(callback): + 1/0 + + @gen.engine + def f(): + try: + yield gen.Task(fail_task) + raise Exception("did not get expected exception") + except ZeroDivisionError: + self.stop() + self.run_gen(f) + def test_with_arg(self): @gen.engine def f(): -- 2.47.2