From: Mike Bayer Date: Wed, 18 Dec 2024 22:19:56 +0000 (-0500) Subject: typing fix: allow stmt.excluded for set_ X-Git-Tag: rel_2_0_37~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=02bd039796264268de38f7f293b95dbb13ca99f1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git typing fix: allow stmt.excluded for set_ Change-Id: I6f0af23fba8f5868282505438e6ca0a5af7e1bbe (cherry picked from commit 5c79e5ce2dd9db491e9177e7f5af0a83058ebe06) --- diff --git a/lib/sqlalchemy/dialects/_typing.py b/lib/sqlalchemy/dialects/_typing.py index 811e125fd5..8e04f3b376 100644 --- a/lib/sqlalchemy/dialects/_typing.py +++ b/lib/sqlalchemy/dialects/_typing.py @@ -13,6 +13,7 @@ from typing import Optional from typing import Union from ..sql import roles +from ..sql.base import ColumnCollection from ..sql.schema import Column from ..sql.schema import ColumnCollectionConstraint from ..sql.schema import Index @@ -23,5 +24,7 @@ _OnConflictIndexElementsT = Optional[ Iterable[Union[Column[Any], str, roles.DDLConstraintColumnRole]] ] _OnConflictIndexWhereT = Optional[roles.WhereHavingRole] -_OnConflictSetT = Optional[Mapping[Any, Any]] +_OnConflictSetT = Optional[ + Union[Mapping[Any, Any], ColumnCollection[Any, Any]] +] _OnConflictWhereT = Optional[roles.WhereHavingRole] diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index 678d22b71f..5e56efba98 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -81,6 +81,9 @@ insert(Test).on_conflict_do_nothing( unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22 ).excluded.foo.desc() +s1 = insert(Test) +s1.on_conflict_do_update(set_=s1.excluded) + # EXPECTED_TYPE: Column[Range[int]] reveal_type(Column(INT4RANGE())) diff --git a/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py b/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py index 00debda509..456f402937 100644 --- a/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py +++ b/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py @@ -21,3 +21,6 @@ unique = UniqueConstraint(name="my_constraint") insert(Test).on_conflict_do_nothing("foo", Test.id > 0).on_conflict_do_update( unique, Test.id > 0, {"id": 42, Test.data: 99}, Test.id == 22 ).excluded.foo.desc() + +s1 = insert(Test) +s1.on_conflict_do_update(set_=s1.excluded)