]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add extra arg for tuple overloads
authorYurii Karabas <1998uriyyo@gmail.com>
Sat, 16 Dec 2023 20:38:05 +0000 (22:38 +0200)
committerYurii Karabas <1998uriyyo@gmail.com>
Sat, 16 Dec 2023 20:38:05 +0000 (22:38 +0200)
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/_selectable_constructors.py
lib/sqlalchemy/sql/dml.py
tools/generate_tuple_map_overloads.py

index d18f9ad5cdba3ddc3d6e6b64ff892fbe1e960aa9..6da6f2af1add1e392473b33f2d759889fd78ca95 100644 (file)
@@ -1575,7 +1575,10 @@ class Query(
         __ent6: _TCCA[_T6],
         __ent7: _TCCA[_T7],
         /,
-    ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]:
+        *entities: _ColumnsClauseArgument[Any],
+    ) -> RowReturningQuery[
+        _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny]
+    ]:
         ...
 
     # END OVERLOADED FUNCTIONS self.with_entities
index 092bde7ecc92f31c42e045e9b2877bf8f31f4f36..eb45fb9dbb70bc6476ecfe965e96e10f54050ebe 100644 (file)
@@ -2892,7 +2892,10 @@ class Session(_SessionClassMethods, EventTarget):
         __ent6: _TCCA[_T6],
         __ent7: _TCCA[_T7],
         /,
-    ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]:
+        *entities: _ColumnsClauseArgument[Any],
+    ) -> RowReturningQuery[
+        _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny]
+    ]:
         ...
 
     # END OVERLOADED FUNCTIONS self.query
index 2df2ef1cb6e708ad94e1a3521150fb81f7bd9205..c8a31b643e9f8cc264f7478d2116ab84bdcb1e2a 100644 (file)
@@ -442,7 +442,10 @@ def select(
     __ent8: _TCCA[_T8],
     __ent9: _TCCA[_T9],
     /,
-) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]:
+    *entities: _ColumnsClauseArgument[Any],
+) -> Select[
+    _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9, Unpack[TupleAny]
+]:
     ...
 
 
index cf605ae4c77eed6470b40274e1304b25c902311e..8d9f995c4cd24075604729d3f268c98c127be022 100644 (file)
@@ -1399,9 +1399,11 @@ class Insert(ValuesBase):
             __ent6: _TCCA[_T6],
             __ent7: _TCCA[_T7],
             /,
-            *,
+            *entities: _ColumnsClauseArgument[Any],
             sort_by_parameter_order: bool = False,
-        ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]:
+        ) -> ReturningInsert[
+            _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny]
+        ]:
             ...
 
         # END OVERLOADED FUNCTIONS self.returning
@@ -1688,7 +1690,10 @@ class Update(DMLWhereBase, ValuesBase):
             __ent6: _TCCA[_T6],
             __ent7: _TCCA[_T7],
             /,
-        ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]:
+            *entities: _ColumnsClauseArgument[Any],
+        ) -> ReturningUpdate[
+            _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny]
+        ]:
             ...
 
         # END OVERLOADED FUNCTIONS self.returning
@@ -1831,7 +1836,10 @@ class Delete(DMLWhereBase, UpdateBase):
             __ent6: _TCCA[_T6],
             __ent7: _TCCA[_T7],
             /,
-        ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]:
+            *entities: _ColumnsClauseArgument[Any],
+        ) -> ReturningDelete[
+            _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny]
+        ]:
             ...
 
         # END OVERLOADED FUNCTIONS self.returning
index 0a300c20eebf9525a499b4f66e6927918dd91066..9ca648333cde1d8608f079df5f39da8a30a74315 100644 (file)
@@ -82,17 +82,26 @@ def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str:
                 )
 
                 for num_args in range(start_index, end_index + 1):
+                    ret_suffix = ""
                     combinations = [
                         f"__ent{arg}: _TCCA[_T{arg}]"
                         for arg in range(num_args)
                     ]
+
+                    if num_args == end_index:
+                        ret_suffix = ", Unpack[TupleAny]"
+                        extra_args = (
+                            f", *entities: _ColumnsClauseArgument[Any]"
+                            f"{extra_args.replace(', *', '')}"
+                        )
+
                     buf.write(
                         textwrap.indent(
                             f"""
 @overload
 def {current_fnname}(
     {'self, ' if use_self else ''}{", ".join(combinations)},/{extra_args}
-) -> {return_type}[{', '.join(f'_T{i}' for i in range(num_args))}]:
+) -> {return_type}[{', '.join(f'_T{i}' for i in range(num_args))}{ret_suffix}]:
     ...
 
 """,  # noqa: E501