class _ClassFinder(ast.NodeVisitor):
- def __init__(self, qualname):
+ def __init__(self, cls, tree, lines, qualname):
self.stack = []
+ self.cls = cls
+ self.tree = tree
+ self.lines = lines
self.qualname = qualname
+ self.lineno_found = []
def visit_FunctionDef(self, node):
self.stack.append(node.name)
line_number = node.lineno
# decrement by one since lines starts with indexing by zero
- line_number -= 1
- raise ClassFoundException(line_number)
+ self.lineno_found.append((line_number - 1, node.end_lineno))
self.generic_visit(node)
self.stack.pop()
+ def get_lineno(self):
+ self.visit(self.tree)
+ lineno_found_number = len(self.lineno_found)
+ if lineno_found_number == 0:
+ raise OSError('could not find class definition')
+ elif lineno_found_number == 1:
+ return self.lineno_found[0][0]
+ else:
+ # We have multiple candidates for the class definition.
+ # Now we have to guess.
+
+ # First, let's see if there are any method definitions
+ for member in self.cls.__dict__.values():
+ if isinstance(member, types.FunctionType):
+ for lineno, end_lineno in self.lineno_found:
+ if lineno <= member.__code__.co_firstlineno <= end_lineno:
+ return lineno
+
+ class_strings = [(''.join(self.lines[lineno: end_lineno]), lineno)
+ for lineno, end_lineno in self.lineno_found]
+
+ # Maybe the class has a docstring and it's unique?
+ if self.cls.__doc__:
+ ret = None
+ for candidate, lineno in class_strings:
+ if self.cls.__doc__.strip() in candidate:
+ if ret is None:
+ ret = lineno
+ else:
+ break
+ else:
+ if ret is not None:
+ return ret
+
+ # We are out of ideas, just return the last one found, which is
+ # slightly better than previous ones
+ return self.lineno_found[-1][0]
+
def findsource(object):
"""Return the entire source file and starting line number for an object.
qualname = object.__qualname__
source = ''.join(lines)
tree = ast.parse(source)
- class_finder = _ClassFinder(qualname)
- try:
- class_finder.visit(tree)
- except ClassFoundException as e:
- line_number = e.args[0]
- return lines, line_number
- else:
- raise OSError('could not find class definition')
+ class_finder = _ClassFinder(object, tree, lines, qualname)
+ return lines, class_finder.get_lineno()
if ismethod(object):
object = object.__func__
self.assertSourceEqual(mod2.cls196.cls200, 198, 201)
def test_class_inside_conditional(self):
- self.assertSourceEqual(mod2.cls238, 238, 240)
self.assertSourceEqual(mod2.cls238.cls239, 239, 240)
def test_multiple_children_classes(self):
self.assertSourceEqual(mod2.cls226, 231, 235)
self.assertSourceEqual(asyncio.run(mod2.cls226().func232()), 233, 234)
+ def test_class_definition_same_name_diff_methods(self):
+ self.assertSourceEqual(mod2.cls296, 296, 298)
+ self.assertSourceEqual(mod2.cls310, 310, 312)
+
class TestNoEOL(GetSourceBase):
def setUp(self):
self.tempdir = TESTFN + '_dir'