]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Improve typing to accommodate sqlalchemy v2
authorCaselIT <cfederico87@gmail.com>
Fri, 10 Feb 2023 21:24:11 +0000 (22:24 +0100)
committermike bayer <mike_mp@zzzcomputing.com>
Sun, 26 Feb 2023 01:50:04 +0000 (01:50 +0000)
Index name can be null.

Fixes: #1168
Change-Id: Id7c944e19a9facd7d3862d43f84fd70aedace999

alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/ddl/postgresql.py
alembic/op.pyi
alembic/operations/batch.py
alembic/operations/ops.py
alembic/operations/schemaobj.py
alembic/util/sqla_compat.py
setup.cfg
tox.ini

index 828a4cd5f5557ebf71526d9f29d679ab1559bbdf..c2181b8cf94cd674739bac7e2c1aec687ff18640 100644 (file)
@@ -640,16 +640,15 @@ def _compare_indexes_and_uniques(
         or sqla_compat._constraint_is_named(c.const, autogen_context.dialect)
     }
 
+    conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _uq_constraint_sig]
+    conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _ix_constraint_sig]
+
     conn_uniques_by_name = {c.name: c for c in conn_unique_constraints}
-    conn_indexes_by_name: Dict[Optional[str], _ix_constraint_sig] = {
-        c.name: c for c in conn_indexes_sig
-    }
+    conn_indexes_by_name = {c.name: c for c in conn_indexes_sig}
     conn_names = {
         c.name: c
-        for c in conn_unique_constraints.union(
-            conn_indexes_sig  # type:ignore[arg-type]
-        )
-        if c.name is not None
+        for c in conn_unique_constraints.union(conn_indexes_sig)
+        if sqla_compat.constraint_name_defined(c.name)
     }
 
     doubled_constraints = {
index 4a144db74a7c881368ccc106aef265820a464544..dc841f83ef38bc10f21c60c0de704a72d4dd72b5 100644 (file)
@@ -17,6 +17,7 @@ from sqlalchemy import schema as sa_schema
 from sqlalchemy import sql
 from sqlalchemy import types as sqltypes
 from sqlalchemy.sql.elements import conv
+from sqlalchemy.sql.elements import quoted_name
 
 from .. import util
 from ..operations import ops
@@ -26,12 +27,10 @@ if TYPE_CHECKING:
     from typing import Literal
 
     from sqlalchemy.sql.elements import ColumnElement
-    from sqlalchemy.sql.elements import quoted_name
     from sqlalchemy.sql.elements import TextClause
     from sqlalchemy.sql.schema import CheckConstraint
     from sqlalchemy.sql.schema import Column
     from sqlalchemy.sql.schema import Constraint
-    from sqlalchemy.sql.schema import DefaultClause
     from sqlalchemy.sql.schema import FetchedValue
     from sqlalchemy.sql.schema import ForeignKey
     from sqlalchemy.sql.schema import ForeignKeyConstraint
@@ -55,12 +54,12 @@ MAX_PYTHON_ARGS = 255
 
 def _render_gen_name(
     autogen_context: AutogenContext,
-    name: Optional[Union[quoted_name, str]],
+    name: sqla_compat._ConstraintName,
 ) -> Optional[Union[quoted_name, str, _f_name]]:
     if isinstance(name, conv):
         return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
     else:
-        return name
+        return sqla_compat.constraint_name_or_none(name)
 
 
 def _indent(text: str) -> str:
@@ -554,7 +553,7 @@ def _ident(name: Optional[Union[quoted_name, str]]) -> Optional[str]:
     """
     if name is None:
         return name
-    elif isinstance(name, sql.elements.quoted_name):
+    elif isinstance(name, quoted_name):
         return str(name)
     elif isinstance(name, str):
         return name
@@ -721,9 +720,7 @@ def _render_column(column: Column, autogen_context: AutogenContext) -> str:
     }
 
 
-def _should_render_server_default_positionally(
-    server_default: Union[Computed, DefaultClause]
-) -> bool:
+def _should_render_server_default_positionally(server_default: Any) -> bool:
     return sqla_compat._server_default_is_computed(
         server_default
     ) or sqla_compat._server_default_is_identity(server_default)
index 32674d2a677f2d9bdb25390169372a4b585d6fd0..e7c85bdc13c21a59d849e0e9bd76ba118dee1ca5 100644 (file)
@@ -419,7 +419,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
 
     def __init__(
         self,
-        constraint_name: Optional[str],
+        constraint_name: sqla_compat._ConstraintName,
         table_name: Union[str, quoted_name],
         elements: Union[
             Sequence[Tuple[str, str]],
@@ -443,7 +443,6 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
         cls, constraint: ExcludeConstraint
     ) -> CreateExcludeConstraintOp:
         constraint_table = sqla_compat._table_for_constraint(constraint)
-
         return cls(
             constraint.name,
             constraint_table.name,
@@ -451,7 +450,9 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
                 (expr, op)
                 for expr, name, op in constraint._render_exprs  # type:ignore[attr-defined] # noqa
             ],
-            where=constraint.where,
+            where=cast(
+                "Optional[Union[BinaryExpression, str]]", constraint.where
+            ),
             schema=constraint_table.schema,
             _orig_constraint=constraint,
             deferrable=constraint.deferrable,
index 5c089e83c0335666c92bf986d107d2d469f7d39a..7a5710eb1f5ccd31dddb4f1b684a50e0032f5baa 100644 (file)
@@ -576,7 +576,7 @@ def create_foreign_key(
     """
 
 def create_index(
-    index_name: str,
+    index_name: Optional[str],
     table_name: str,
     columns: Sequence[Union[str, TextClause, Function]],
     schema: Optional[str] = None,
index 0c773c68ccc8de3e9df914215869318cb5cc19c8..fe32eec2672889423b78e3cbfe402aa9e2f55109 100644 (file)
@@ -33,6 +33,7 @@ from ..util.sqla_compat import _is_type_bound
 from ..util.sqla_compat import _remove_column_from_collection
 from ..util.sqla_compat import _resolve_for_variant
 from ..util.sqla_compat import _select
+from ..util.sqla_compat import constraint_name_defined
 
 if TYPE_CHECKING:
     from typing import Literal
@@ -268,7 +269,7 @@ class ApplyBatchImpl:
                 # because
                 # we have no way to determine _is_type_bound() for these.
                 pass
-            elif const.name:
+            elif constraint_name_defined(const.name):
                 self.named_constraints[const.name] = const
             else:
                 self.unnamed_constraints.append(const)
@@ -662,7 +663,7 @@ class ApplyBatchImpl:
         """
 
     def add_constraint(self, const: Constraint) -> None:
-        if not const.name:
+        if not constraint_name_defined(const.name):
             raise ValueError("Constraint must have a name")
         if isinstance(const, sql_schema.PrimaryKeyConstraint):
             if self.table.primary_key in self.unnamed_constraints:
@@ -681,7 +682,7 @@ class ApplyBatchImpl:
                     if col_const.name == const.name:
                         self.columns[col.name].constraints.remove(col_const)
             else:
-                assert const.name
+                assert constraint_name_defined(const.name)
                 const = self.named_constraints.pop(const.name)
         except KeyError:
             if _is_type_bound(const):
index 3cdd170d35ede485d9cfdb57e539518601990cf9..8e1144bb07a6224c829d4fc928b553f4b8de76f6 100644 (file)
@@ -154,10 +154,7 @@ class DropConstraintOp(MigrateOperation):
             return ("remove_constraint", self.to_constraint())
 
     @classmethod
-    def from_constraint(
-        cls,
-        constraint: Constraint,
-    ) -> DropConstraintOp:
+    def from_constraint(cls, constraint: Constraint) -> DropConstraintOp:
         types = {
             "unique_constraint": "unique",
             "foreign_key_constraint": "foreignkey",
@@ -169,7 +166,7 @@ class DropConstraintOp(MigrateOperation):
 
         constraint_table = sqla_compat._table_for_constraint(constraint)
         return cls(
-            constraint.name,
+            sqla_compat.constraint_name_or_none(constraint.name),
             constraint_table.name,
             schema=constraint_table.schema,
             type_=types[constraint.__visit_name__],
@@ -274,9 +271,8 @@ class CreatePrimaryKeyOp(AddConstraintOp):
     def from_constraint(cls, constraint: Constraint) -> CreatePrimaryKeyOp:
         constraint_table = sqla_compat._table_for_constraint(constraint)
         pk_constraint = cast("PrimaryKeyConstraint", constraint)
-
         return cls(
-            pk_constraint.name,
+            sqla_compat.constraint_name_or_none(pk_constraint.name),
             constraint_table.name,
             pk_constraint.columns.keys(),
             schema=constraint_table.schema,
@@ -411,7 +407,7 @@ class CreateUniqueConstraintOp(AddConstraintOp):
             kw["initially"] = uq_constraint.initially
         kw.update(uq_constraint.dialect_kwargs)
         return cls(
-            uq_constraint.name,
+            sqla_compat.constraint_name_or_none(uq_constraint.name),
             constraint_table.name,
             [c.name for c in uq_constraint.columns],
             schema=constraint_table.schema,
@@ -567,7 +563,7 @@ class CreateForeignKeyOp(AddConstraintOp):
         kw["referent_schema"] = target_schema
         kw.update(fk_constraint.dialect_kwargs)
         return cls(
-            fk_constraint.name,
+            sqla_compat.constraint_name_or_none(fk_constraint.name),
             source_table,
             target_table,
             source_columns,
@@ -753,9 +749,8 @@ class CreateCheckConstraintOp(AddConstraintOp):
         constraint_table = sqla_compat._table_for_constraint(constraint)
 
         ck_constraint = cast("CheckConstraint", constraint)
-
         return cls(
-            ck_constraint.name,
+            sqla_compat.constraint_name_or_none(ck_constraint.name),
             constraint_table.name,
             cast("ColumnElement[Any]", ck_constraint.sqltext),
             schema=constraint_table.schema,
@@ -863,7 +858,7 @@ class CreateIndexOp(MigrateOperation):
 
     def __init__(
         self,
-        index_name: str,
+        index_name: Optional[str],
         table_name: str,
         columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
         schema: Optional[str] = None,
@@ -914,7 +909,7 @@ class CreateIndexOp(MigrateOperation):
     def create_index(
         cls,
         operations: Operations,
-        index_name: str,
+        index_name: Optional[str],
         table_name: str,
         columns: Sequence[Union[str, TextClause, Function]],
         schema: Optional[str] = None,
index dfda8bbeaad0d2910ef3bb1dc53229269693cef6..ba09b3bb1634cc738c58a1c8394b78a54e81fd29 100644 (file)
@@ -235,7 +235,7 @@ class SchemaObjects:
 
     def index(
         self,
-        name: str,
+        name: Optional[str],
         tablename: Optional[str],
         columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
         schema: Optional[str] = None,
index 23255be3bcbfc8d9627461ac9ec0f59ee44de9ca..9bdcfc3be7d8efb1de8faba752cefe67b1e50d76 100644 (file)
@@ -26,6 +26,7 @@ from sqlalchemy.sql.elements import BindParameter
 from sqlalchemy.sql.elements import quoted_name
 from sqlalchemy.sql.elements import TextClause
 from sqlalchemy.sql.visitors import traverse
+from typing_extensions import TypeGuard
 
 if TYPE_CHECKING:
     from sqlalchemy import Index
@@ -103,6 +104,22 @@ else:
     _identity_attrs = _identity_options_attrs + ("on_null",)
     has_identity = True
 
+if sqla_2:
+    from sqlalchemy.sql.base import _NoneName
+else:
+    from sqlalchemy.util import symbol as _NoneName  # type: ignore[assignment]
+
+_ConstraintName = Union[None, str, _NoneName]
+
+
+def constraint_name_defined(name: _ConstraintName) -> TypeGuard[str]:
+    return isinstance(name, str)
+
+
+def constraint_name_or_none(name: _ConstraintName) -> Optional[str]:
+    return name if constraint_name_defined(name) else None
+
+
 AUTOINCREMENT_DEFAULT = "auto"
 
 
index 0d9ce1a72881548f94c704d75ebfe61dce39f2f6..4741eb726586b4ab56810415c18698cee3004312 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -43,6 +43,7 @@ install_requires =
     Mako
     importlib-metadata;python_version<"3.9"
     importlib-resources;python_version<"3.9"
+    typing-extensions>=4
 
 [options.extras_require]
 tz =
diff --git a/tox.ini b/tox.ini
index 8b744d7c6fdcd542bc7a7bfd576926377d8ba85b..4cc54450247189651f246ed5bd05a71d87b15b15 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -70,8 +70,7 @@ commands=
 basepython = python3
 deps=
     mypy
-    sqlalchemy>=1.4.0
-    sqlalchemy2-stubs
+    sqlalchemy>=2
     mako
     types-pkg-resources
     types-python-dateutil