]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
🔧 Add sqlmodel.sql.expression generation script (select overloads)
authorSebastián Ramírez <tiangolo@gmail.com>
Tue, 24 Aug 2021 13:16:41 +0000 (15:16 +0200)
committerSebastián Ramírez <tiangolo@gmail.com>
Tue, 24 Aug 2021 13:16:41 +0000 (15:16 +0200)
scripts/generate_select.py [new file with mode: 0644]

diff --git a/scripts/generate_select.py b/scripts/generate_select.py
new file mode 100644 (file)
index 0000000..b66a167
--- /dev/null
@@ -0,0 +1,55 @@
+from itertools import product
+from pathlib import Path
+from typing import List, Tuple
+
+import black
+from jinja2 import Template
+from pydantic import BaseModel
+
+template_path = Path(__file__).parent.parent / "sqlmodel/sql/expression.py.jinja2"
+destiny_path = Path(__file__).parent.parent / "sqlmodel/sql/expression.py"
+
+
+number_of_types = 4
+
+
+class Arg(BaseModel):
+    name: str
+    annotation: str
+
+
+arg_groups: List[Arg] = []
+
+signatures: List[Tuple[List[Arg], List[str]]] = []
+
+for total_args in range(2, number_of_types + 1):
+    arg_types_tuples = product(["scalar", "model"], repeat=total_args)
+    for arg_type_tuple in arg_types_tuples:
+        args: List[Arg] = []
+        return_types: List[str] = []
+        for i, arg_type in enumerate(arg_type_tuple):
+            if arg_type == "scalar":
+                t_var = f"_TScalar_{i}"
+                arg = Arg(name=f"entity_{i}", annotation=t_var)
+                ret_type = t_var
+            else:
+                t_type = f"_TModel_{i}"
+                t_var = f"Type[{t_type}]"
+                arg = Arg(name=f"entity_{i}", annotation=t_var)
+                ret_type = t_type
+            args.append(arg)
+            return_types.append(ret_type)
+        signatures.append((args, return_types))
+
+template: Template = Template(template_path.read_text())
+
+result = template.render(number_of_types=number_of_types, signatures=signatures)
+
+result = (
+    "# WARNING: do not modify this code, it is generated by "
+    "expression.py.jinja2\n\n" + result
+)
+
+result = black.format_str(result, mode=black.Mode())
+
+destiny_path.write_text(result)