]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix type annotation of schema.MetaData.naming_convention
authorFederico Caselli <cfederico87@gmail.com>
Mon, 7 Aug 2023 20:37:12 +0000 (22:37 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 8 Aug 2023 19:22:06 +0000 (21:22 +0200)
Fixed #9600
Closes: #9598
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9598
Pull-request-sha: d458e36242c60f44b1140697ae04cc8ad8ac5b98

Change-Id: I2290d1955c21f8d40765f78e486c739263f6e4ab

lib/sqlalchemy/sql/schema.py
test/typing/plain_files/sql/schema.py [new file with mode: 0644]

index 721b9ee63789d2ac8be2914f6b078df0f5dbed75..008ae2c0059b722e8dbf302e3684fa09e4cb5c24 100644 (file)
@@ -43,12 +43,14 @@ from typing import Dict
 from typing import Iterable
 from typing import Iterator
 from typing import List
+from typing import Mapping
 from typing import NoReturn
 from typing import Optional
 from typing import overload
 from typing import Sequence as _typing_Sequence
 from typing import Set
 from typing import Tuple
+from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
@@ -85,6 +87,7 @@ from ..util.typing import Final
 from ..util.typing import Literal
 from ..util.typing import Protocol
 from ..util.typing import Self
+from ..util.typing import TypedDict
 from ..util.typing import TypeGuard
 
 if typing.TYPE_CHECKING:
@@ -5283,8 +5286,35 @@ class Index(
         )
 
 
-DEFAULT_NAMING_CONVENTION: util.immutabledict[str, str] = util.immutabledict(
-    {"ix": "ix_%(column_0_label)s"}
+_AllConstraints = Union[
+    Index,
+    UniqueConstraint,
+    CheckConstraint,
+    ForeignKeyConstraint,
+    PrimaryKeyConstraint,
+]
+
+_NamingSchemaCallable = Callable[[_AllConstraints, Table], str]
+
+
+class _NamingSchemaTD(TypedDict, total=False):
+    fk: Union[str, _NamingSchemaCallable]
+    pk: Union[str, _NamingSchemaCallable]
+    ix: Union[str, _NamingSchemaCallable]
+    ck: Union[str, _NamingSchemaCallable]
+    uq: Union[str, _NamingSchemaCallable]
+
+
+_NamingSchemaParameter = Union[
+    _NamingSchemaTD,
+    Mapping[
+        Union[Type[_AllConstraints], str], Union[str, _NamingSchemaCallable]
+    ],
+]
+
+
+DEFAULT_NAMING_CONVENTION: _NamingSchemaParameter = util.immutabledict(
+    {"ix": "ix_%(column_0_label)s"}  # type: ignore[arg-type]
 )
 
 
@@ -5319,7 +5349,7 @@ class MetaData(HasSchemaAttr):
         self,
         schema: Optional[str] = None,
         quote_schema: Optional[bool] = None,
-        naming_convention: Optional[Dict[str, str]] = None,
+        naming_convention: Optional[_NamingSchemaParameter] = None,
         info: Optional[_InfoType] = None,
     ) -> None:
         """Create a new MetaData object.
diff --git a/test/typing/plain_files/sql/schema.py b/test/typing/plain_files/sql/schema.py
new file mode 100644 (file)
index 0000000..1e0a134
--- /dev/null
@@ -0,0 +1,33 @@
+from typing import Union
+
+from sqlalchemy import Constraint
+from sqlalchemy import Index
+from sqlalchemy import MetaData
+from sqlalchemy import Table
+
+MetaData(
+    naming_convention={
+        "ix": "ix_%(column_0_label)s",
+        "uq": "uq_%(table_name)s_%(column_0_name)s",
+        "ck": "ck_%(table_name)s_%(constraint_name)s",
+        "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+        "pk": "pk_%(table_name)s",
+    }
+)
+
+
+MetaData(naming_convention={"uq": "uq_%(table_name)s_%(column_0_N_name)s"})
+
+
+def fk_guid(constraint: Union[Constraint, Index], table: Table) -> str:
+    return "foo"
+
+
+MetaData(
+    naming_convention={
+        "fk_guid": fk_guid,
+        "ix": "ix_%(column_0_label)s",
+        "fk": "fk_%(fk_guid)s",
+        "foo": lambda c, t: t.name + str(c.name),
+    }
+)