]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add types
authorAlessio Bogon <youtux@gmail.com>
Sun, 1 Sep 2019 11:55:48 +0000 (13:55 +0200)
committerAlessio Bogon <youtux@gmail.com>
Sun, 1 Sep 2019 11:55:48 +0000 (13:55 +0200)
lib/sqlalchemy/ext/linter.py

index 4e53b0b164203774e40573137209d84138193b6c..160731745d48fcd63db64155156082afb09a4009 100644 (file)
@@ -12,15 +12,17 @@ def before_execute_hook(conn, clauseelement, multiparams, params):
         lint(clauseelement)
 
 
-def find_unmatching_froms(element, start_with=None):
+def find_unmatching_froms(query, start_with=None):
+    # type: (Select, Optional[FromClause]) -> Tuple[Set[FromClause], FromClause]
     # TODO: It would be nicer to use OrderedSet, but it seems to not be too much optimize, so let's skip for now
-    froms = set(element.froms)
+    froms = set(query.froms)
     if not froms:
-        return
+        return None, None
     edges = set()
 
     # find all "a <operator> b", add that as edges
     def visit_binary(binary_element):
+        # type: (BinaryExpression) -> None
         edges.update(
             itertools.product(
                 binary_element.left._from_objects,
@@ -30,12 +32,14 @@ def find_unmatching_froms(element, start_with=None):
 
     # find all "a JOIN b", add "a" and "b" as froms
     def visit_join(join_element):
+        # type: (Join) -> None
         if join_element in froms:
             froms.remove(join_element)
             froms.update((join_element.left, join_element.right))
 
     # unwrap "FromGrouping" objects, e.g. parentheized froms
     def visit_grouping(grouping_element):
+        # type: (FromGrouping) -> None
         if grouping_element in froms:
             froms.remove(grouping_element)
 
@@ -45,7 +49,7 @@ def find_unmatching_froms(element, start_with=None):
             froms.add(grouping_element.element)
 
     visitors.traverse(
-        element,
+        query,
         {},
         {
             "binary": visit_binary,
@@ -86,6 +90,7 @@ def find_unmatching_froms(element, start_with=None):
 
 
 def warn_for_unmatching_froms(query):
+    # type: (Select) -> None
     froms, start_with = find_unmatching_froms(query)
     if froms:
         util.warn(
@@ -99,4 +104,5 @@ def warn_for_unmatching_froms(query):
 
 
 def lint(query):
+    # type: (Select) -> None
     warn_for_unmatching_froms(query)