]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Removed server default quoting from compare
authorFederico Caselli <cfederico87@gmail.com>
Thu, 11 May 2023 19:49:14 +0000 (21:49 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 11 May 2023 20:41:09 +0000 (22:41 +0200)
Don't modify the metadata server default when comparing it in the
autogenerate process.
This impacts the value passes to user provided functions passed in
:paramref:`.EnvironmentContext.configure.compare_server_default`
and third party dialect that implement a custom ``compare_server_default``.

Fixes: #1178
Change-Id: Ib429efcf9077337f768ad5aad91659867e89391a

alembic/autogenerate/compare.py
alembic/context.pyi
alembic/ddl/mysql.py
alembic/runtime/environment.py
docs/build/unreleased/1178.rst [new file with mode: 0644]
tests/test_postgresql.py

index 595631ceb550e64d0857d9246efcb70e3ef6ff07..b489328bd005c74fbfbbb4529301a06f1b44811c 100644 (file)
@@ -1023,9 +1023,7 @@ def _compare_type(
 
 
 def _render_server_default_for_compare(
-    metadata_default: Optional[Any],
-    metadata_col: Column,
-    autogen_context: AutogenContext,
+    metadata_default: Optional[Any], autogen_context: AutogenContext
 ) -> Optional[str]:
 
     if isinstance(metadata_default, sa_schema.DefaultClause):
@@ -1039,11 +1037,7 @@ def _render_server_default_for_compare(
                 )
             )
     if isinstance(metadata_default, str):
-        if metadata_col.type._type_affinity is sqltypes.String:
-            metadata_default = re.sub(r"^'|'$", "", metadata_default)
-            return f"'{metadata_default}'"
-        else:
-            return metadata_default
+        return metadata_default
     else:
         return None
 
@@ -1190,7 +1184,7 @@ def _compare_server_default(
                 )
     else:
         rendered_metadata_default = _render_server_default_for_compare(
-            metadata_default, metadata_col, autogen_context
+            metadata_default, autogen_context
         )
 
         rendered_conn_default = (
index c81a14fd06bda5a5359f751f7d1cfc03ac1c5cb9..621599d345ba320f7923c9b52ca6a72d7f48cec4 100644 (file)
@@ -22,6 +22,8 @@ if TYPE_CHECKING:
     from sqlalchemy.engine.base import Connection
     from sqlalchemy.engine.url import URL
     from sqlalchemy.sql.elements import ClauseElement
+    from sqlalchemy.sql.schema import Column
+    from sqlalchemy.sql.schema import FetchedValue
     from sqlalchemy.sql.schema import MetaData
     from sqlalchemy.sql.schema import SchemaItem
 
@@ -144,7 +146,20 @@ def configure(
         ]
     ] = None,
     compare_type: bool = False,
-    compare_server_default: bool = False,
+    compare_server_default: Union[
+        bool,
+        Callable[
+            [
+                MigrationContext,
+                Column,
+                Column,
+                Optional[str],
+                Optional[FetchedValue],
+                Optional[str],
+            ],
+            Optional[bool],
+        ],
+    ] = False,
     render_item: Optional[
         Callable[[str, Any, AutogenContext], Union[str, Literal[False]]]
     ] = None,
index a452760227227789a7d80312d154a145fc4626b4..5e66f53823bc29f5a028b682d532a2a5deea8cdd 100644 (file)
@@ -185,13 +185,22 @@ class MySQLImpl(DefaultImpl):
             and rendered_inspector_default == "'0'"
         ):
             return False
-        elif inspector_column.type._type_affinity is sqltypes.Integer:
+        elif (
+            rendered_inspector_default
+            and inspector_column.type._type_affinity is sqltypes.Integer
+        ):
             rendered_inspector_default = (
                 re.sub(r"^'|'$", "", rendered_inspector_default)
                 if rendered_inspector_default is not None
                 else None
             )
             return rendered_inspector_default != rendered_metadata_default
+        elif (
+            rendered_metadata_default
+            and metadata_column.type._type_affinity is sqltypes.String
+        ):
+            metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default)
+            return rendered_inspector_default != f"'{metadata_default}'"
         elif rendered_inspector_default and rendered_metadata_default:
             # adjust for "function()" vs. "FUNCTION" as can occur particularly
             # for the CURRENT_TIMESTAMP function on newer MariaDB versions
index 71a53091111f68cf290e20112663598f665e26f8..3087377361dd4ac42cfd3435af4c25b739a95740 100644 (file)
@@ -15,6 +15,8 @@ from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import Union
 
+from sqlalchemy.sql.schema import Column
+from sqlalchemy.sql.schema import FetchedValue
 from typing_extensions import Literal
 
 from .migration import _ProxyTransaction
@@ -79,6 +81,18 @@ OnVersionApplyFn = Callable[
     None,
 ]
 
+CompareServerDefault = Callable[
+    [
+        MigrationContext,
+        Column,
+        Column,
+        Optional[str],
+        Optional[FetchedValue],
+        Optional[str],
+    ],
+    Optional[bool],
+]
+
 
 class EnvironmentContext(util.ModuleClsProxy):
 
@@ -398,7 +412,7 @@ class EnvironmentContext(util.ModuleClsProxy):
             ProcessRevisionDirectiveFn
         ] = None,
         compare_type: bool = False,
-        compare_server_default: bool = False,
+        compare_server_default: Union[bool, CompareServerDefault] = False,
         render_item: Optional[RenderItemFn] = None,
         literal_binds: bool = False,
         upgrade_token: str = "upgrades",
diff --git a/docs/build/unreleased/1178.rst b/docs/build/unreleased/1178.rst
new file mode 100644 (file)
index 0000000..25789d3
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: changed, autogenerate
+    :tickets: 1178
+
+    Don't modify the metadata server default when comparing it in the
+    autogenerate process.
+    This impacts the value passes to user provided functions passed in
+    :paramref:`.EnvironmentContext.configure.compare_server_default`
+    and third party dialect that implement a custom ``compare_server_default``.
index 77ed4dabfcefdb46e4eb065872645f7eff5364e2..818dae7cd4a7d88aad4857b4afcbd886a515ce19 100644 (file)
@@ -846,7 +846,7 @@ class PostgresqlDetectSerialTest(TestBase):
 
         eq_(
             _render_server_default_for_compare(
-                tab.c.x.server_default, tab.c.x, self.autogen_context
+                tab.c.x.server_default, self.autogen_context
             ),
             c_expected,
         )
@@ -867,7 +867,7 @@ class PostgresqlDetectSerialTest(TestBase):
         server_default = diffs[0][0][4]["existing_server_default"]
         eq_(
             _render_server_default_for_compare(
-                server_default, tab.c.x, self.autogen_context
+                server_default, self.autogen_context
             ),
             c_expected,
         )