From: Federico Caselli Date: Thu, 11 May 2023 19:49:14 +0000 (+0200) Subject: Removed server default quoting from compare X-Git-Tag: rel_1_11_0~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=230a2932f646800b006c00b434be95c164598525;p=thirdparty%2Fsqlalchemy%2Falembic.git Removed server default quoting from compare Don't modify the metadata server default when comparing it in the autogenerate process. This impacts the value passes to user provided functions passed in :paramref:`.EnvironmentContext.configure.compare_server_default` and third party dialect that implement a custom ``compare_server_default``. Fixes: #1178 Change-Id: Ib429efcf9077337f768ad5aad91659867e89391a --- diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index 595631ce..b489328b 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -1023,9 +1023,7 @@ def _compare_type( def _render_server_default_for_compare( - metadata_default: Optional[Any], - metadata_col: Column, - autogen_context: AutogenContext, + metadata_default: Optional[Any], autogen_context: AutogenContext ) -> Optional[str]: if isinstance(metadata_default, sa_schema.DefaultClause): @@ -1039,11 +1037,7 @@ def _render_server_default_for_compare( ) ) if isinstance(metadata_default, str): - if metadata_col.type._type_affinity is sqltypes.String: - metadata_default = re.sub(r"^'|'$", "", metadata_default) - return f"'{metadata_default}'" - else: - return metadata_default + return metadata_default else: return None @@ -1190,7 +1184,7 @@ def _compare_server_default( ) else: rendered_metadata_default = _render_server_default_for_compare( - metadata_default, metadata_col, autogen_context + metadata_default, autogen_context ) rendered_conn_default = ( diff --git a/alembic/context.pyi b/alembic/context.pyi index c81a14fd..621599d3 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -22,6 +22,8 @@ if TYPE_CHECKING: from sqlalchemy.engine.base import Connection from sqlalchemy.engine.url import URL from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import FetchedValue from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.schema import SchemaItem @@ -144,7 +146,20 @@ def configure( ] ] = None, compare_type: bool = False, - compare_server_default: bool = False, + compare_server_default: Union[ + bool, + Callable[ + [ + MigrationContext, + Column, + Column, + Optional[str], + Optional[FetchedValue], + Optional[str], + ], + Optional[bool], + ], + ] = False, render_item: Optional[ Callable[[str, Any, AutogenContext], Union[str, Literal[False]]] ] = None, diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py index a4527602..5e66f538 100644 --- a/alembic/ddl/mysql.py +++ b/alembic/ddl/mysql.py @@ -185,13 +185,22 @@ class MySQLImpl(DefaultImpl): and rendered_inspector_default == "'0'" ): return False - elif inspector_column.type._type_affinity is sqltypes.Integer: + elif ( + rendered_inspector_default + and inspector_column.type._type_affinity is sqltypes.Integer + ): rendered_inspector_default = ( re.sub(r"^'|'$", "", rendered_inspector_default) if rendered_inspector_default is not None else None ) return rendered_inspector_default != rendered_metadata_default + elif ( + rendered_metadata_default + and metadata_column.type._type_affinity is sqltypes.String + ): + metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default) + return rendered_inspector_default != f"'{metadata_default}'" elif rendered_inspector_default and rendered_metadata_default: # adjust for "function()" vs. "FUNCTION" as can occur particularly # for the CURRENT_TIMESTAMP function on newer MariaDB versions diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index 71a53091..30873773 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -15,6 +15,8 @@ from typing import Tuple from typing import TYPE_CHECKING from typing import Union +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.schema import FetchedValue from typing_extensions import Literal from .migration import _ProxyTransaction @@ -79,6 +81,18 @@ OnVersionApplyFn = Callable[ None, ] +CompareServerDefault = Callable[ + [ + MigrationContext, + Column, + Column, + Optional[str], + Optional[FetchedValue], + Optional[str], + ], + Optional[bool], +] + class EnvironmentContext(util.ModuleClsProxy): @@ -398,7 +412,7 @@ class EnvironmentContext(util.ModuleClsProxy): ProcessRevisionDirectiveFn ] = None, compare_type: bool = False, - compare_server_default: bool = False, + compare_server_default: Union[bool, CompareServerDefault] = False, render_item: Optional[RenderItemFn] = None, literal_binds: bool = False, upgrade_token: str = "upgrades", diff --git a/docs/build/unreleased/1178.rst b/docs/build/unreleased/1178.rst new file mode 100644 index 00000000..25789d3c --- /dev/null +++ b/docs/build/unreleased/1178.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: changed, autogenerate + :tickets: 1178 + + Don't modify the metadata server default when comparing it in the + autogenerate process. + This impacts the value passes to user provided functions passed in + :paramref:`.EnvironmentContext.configure.compare_server_default` + and third party dialect that implement a custom ``compare_server_default``. diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 77ed4dab..818dae7c 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -846,7 +846,7 @@ class PostgresqlDetectSerialTest(TestBase): eq_( _render_server_default_for_compare( - tab.c.x.server_default, tab.c.x, self.autogen_context + tab.c.x.server_default, self.autogen_context ), c_expected, ) @@ -867,7 +867,7 @@ class PostgresqlDetectSerialTest(TestBase): server_default = diffs[0][0][4]["existing_server_default"] eq_( _render_server_default_for_compare( - server_default, tab.c.x, self.autogen_context + server_default, self.autogen_context ), c_expected, )