]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- Some deep-in-the-weeds fixes to try to get "server default" comparison
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jun 2014 17:49:44 +0000 (13:49 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jun 2014 17:49:44 +0000 (13:49 -0400)
working better across platforms and expressions, in particular on
the Postgresql backend, mostly dealing with quoting/not quoting of various
expressions at the appropriate time and on a per-backend basis.
Repaired and tested support for such defaults as Postgresql interval
and array defaults.
fixes #212

alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/ddl/postgresql.py
alembic/ddl/sqlite.py
docs/build/changelog.rst
tests/test_postgresql.py

index ec077fdd3b98cff3b694f6cf78b236cdfbfa7de1..0d58bec554e76339a133be1f08c453eb19c41a6b 100644 (file)
@@ -454,6 +454,12 @@ def _compare_type(schema, tname, cname, conn_col,
             conn_type, metadata_type, tname, cname
         )
 
+def _render_server_default_for_compare(metadata_default,
+                                        metadata_col, autogen_context):
+    return _render_server_default(
+                    metadata_default, autogen_context,
+                    repr_=metadata_col.type._type_affinity is sqltypes.String)
+
 def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
                                 diffs, autogen_context):
 
@@ -461,8 +467,9 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
     conn_col_default = conn_col.server_default
     if conn_col_default is None and metadata_default is None:
         return False
-    rendered_metadata_default = _render_server_default(
-                            metadata_default, autogen_context)
+    rendered_metadata_default = _render_server_default_for_compare(
+                    metadata_default, metadata_col, autogen_context)
+
     rendered_conn_default = conn_col.server_default.arg.text \
                             if conn_col.server_default else None
     isdiff = autogen_context['context']._compare_server_default(
index ed9536c79642c614d3bc3c410547ca4996ba8eec..3828d87f0f7a95094bddaee478dd46859a239478 100644 (file)
@@ -307,7 +307,7 @@ def _render_column(column, autogen_context):
         'kw': ", ".join(["%s=%s" % (kwname, val) for kwname, val in opts])
     }
 
-def _render_server_default(default, autogen_context):
+def _render_server_default(default, autogen_context, repr_=True):
     rendered = _user_defined_render("server_default", default, autogen_context)
     if rendered is not False:
         return rendered
@@ -319,11 +319,11 @@ def _render_server_default(default, autogen_context):
             default = str(default.arg.compile(
                             dialect=autogen_context['dialect']))
     if isinstance(default, string_types):
-        # TODO: this is just a hack to get
-        # tests to pass until we figure out
-        # WTF sqlite is doing
-        default = re.sub(r"^'|'$", "", default)
-        return repr(default)
+        if repr_:
+            default = re.sub(r"^'|'$", "", default)
+            return repr(default)
+        else:
+            return default
     else:
         return None
 
index 5ca0d1f592a8e07956d102cefefe98106ca61d91..27f31b0ee2f34c721721d69b4642c4e3feb0664f 100644 (file)
@@ -1,7 +1,7 @@
 import re
 
 from sqlalchemy import types as sqltypes
-
+from .. import compat
 from .base import compiles, alter_table, format_table_name, RenameTable
 from .impl import DefaultImpl
 
@@ -24,8 +24,11 @@ class PostgresqlImpl(DefaultImpl):
         if None in (conn_col_default, rendered_metadata_default):
             return conn_col_default != rendered_metadata_default
 
-        if metadata_column.type._type_affinity is not sqltypes.String:
-            rendered_metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default)
+        if metadata_column.server_default is not None and \
+            isinstance(metadata_column.server_default.arg,
+                        compat.string_types) and \
+                not re.match(r"^'.+'$", rendered_metadata_default):
+            rendered_metadata_default = "'%s'" % rendered_metadata_default
 
         return not self.connection.scalar(
             "SELECT %s = %s" % (
index a3c73ce5355832b041649e0b971bcba781c4e752..85c829e3c7194610696c7972709a65df71348117 100644 (file)
@@ -1,5 +1,6 @@
 from .. import util
 from .impl import DefaultImpl
+import re
 
 #from sqlalchemy.ext.compiler import compiles
 #from .base import AddColumn, alter_table
@@ -29,6 +30,14 @@ class SQLiteImpl(DefaultImpl):
             raise NotImplementedError(
                     "No support for ALTER of constraints in SQLite dialect")
 
+    def compare_server_default(self, inspector_column,
+                            metadata_column,
+                            rendered_metadata_default,
+                            rendered_inspector_default):
+
+        rendered_metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default)
+        return rendered_inspector_default != repr(rendered_metadata_default)
+
     def correct_for_autogen_constraints(self, conn_unique_constraints, conn_indexes,
                                         metadata_unique_constraints,
                                         metadata_indexes):
index 2d84283cb219239673ce37fba60dab1f026e5a13..33053e2c3b91c034c6c57d5761c9111e5848ba6b 100644 (file)
@@ -5,6 +5,17 @@ Changelog
 .. changelog::
     :version: 0.6.6
 
+    .. change::
+      :tags: bug
+      :tickets: 212
+
+      Some deep-in-the-weeds fixes to try to get "server default" comparison
+      working better across platforms and expressions, in particular on
+      the Postgresql backend, mostly dealing with quoting/not quoting of various
+      expressions at the appropriate time and on a per-backend basis.
+      Repaired and tested support for such defaults as Postgresql interval
+      and array defaults.
+
     .. change::
       :tags: enhancement
       :tickets: 209
index bd1c0b4ee731d278e070be48a2047164cd81b47d..2e0965e7bc7ff33455cabf0f14bc0fbe1d46051f 100644 (file)
@@ -1,9 +1,13 @@
 from unittest import TestCase
 
-from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, String
+from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, \
+        String, Interval
+from sqlalchemy.dialects.postgresql import ARRAY
+from sqlalchemy.schema import DefaultClause
 from sqlalchemy.engine.reflection import Inspector
 from alembic.operations import Operations
 from sqlalchemy.sql import table, column
+from alembic.autogenerate.compare import _compare_server_default
 
 from alembic import command, util
 from alembic.migration import MigrationContext
@@ -158,7 +162,7 @@ class PostgresqlDefaultCompareTest(TestCase):
             'connection': connection,
             'dialect': connection.dialect,
             'context': context
-            }
+        }
 
     @classmethod
     def teardown_class(cls):
@@ -170,60 +174,111 @@ class PostgresqlDefaultCompareTest(TestCase):
     def tearDown(self):
         self.metadata.drop_all()
 
-    def _compare_default_roundtrip(self, type_, txt, alternate=None):
-        if alternate:
-            expected = True
-        else:
-            alternate = txt
-            expected = False
-        t = Table("test", self.metadata,
-            Column("somecol", type_, server_default=text(txt))
-        )
+    def _compare_default_roundtrip(self, type_, orig_default, alternate=None):
+        diff_expected = alternate is not None
+        if alternate is None:
+            alternate = orig_default
+
+        t1 = Table("test", self.metadata,
+            Column("somecol", type_, server_default=orig_default))
         t2 = Table("test", MetaData(),
-            Column("somecol", type_, server_default=text(alternate))
-        )
-        assert self._compare_default(
-            t, t2, t2.c.somecol, alternate
-        ) is expected
+            Column("somecol", type_, server_default=alternate))
+
+        t1.create(self.bind)
+
+        insp = Inspector.from_engine(self.bind)
+        cols = insp.get_columns(t1.name)
+        insp_col = Column("somecol", cols[0]['type'],
+                                server_default=text(cols[0]['default']))
+        diffs = []
+        _compare_server_default(None, "test", "somecol", insp_col,
+                t2.c.somecol, diffs, self.autogen_context)
+        eq_(bool(diffs), diff_expected)
 
     def _compare_default(
         self,
         t1, t2, col,
         rendered
     ):
-        t1.create(self.bind)
+        t1.create(self.bind, checkfirst=True)
         insp = Inspector.from_engine(self.bind)
         cols = insp.get_columns(t1.name)
         ctx = self.autogen_context['context']
+
         return ctx.impl.compare_server_default(
             None,
             col,
             rendered,
             cols[0]['default'])
 
-    def test_compare_current_timestamp(self):
+    def test_compare_interval_str(self):
+        # this form shouldn't be used but testing here
+        # for compatibility
+        self._compare_default_roundtrip(
+            Interval,
+            "14 days"
+        )
+
+    def test_compare_interval_text(self):
+        self._compare_default_roundtrip(
+            Interval,
+            text("'14 days'")
+        )
+
+    def test_compare_array_of_integer_text(self):
+        self._compare_default_roundtrip(
+            ARRAY(Integer),
+            text("(ARRAY[]::integer[])")
+        )
+
+    def test_compare_current_timestamp_text(self):
         self._compare_default_roundtrip(
             DateTime(),
-            "TIMEZONE('utc', CURRENT_TIMESTAMP)",
+            text("TIMEZONE('utc', CURRENT_TIMESTAMP)"),
         )
 
-    def test_compare_integer(self):
+    def test_compare_integer_str(self):
         self._compare_default_roundtrip(
             Integer(),
             "5",
         )
 
-    def test_compare_integer_diff(self):
+    def test_compare_integer_text(self):
         self._compare_default_roundtrip(
             Integer(),
-            "5", "7"
+            text("5"),
+        )
+
+    def test_compare_integer_text_diff(self):
+        self._compare_default_roundtrip(
+            Integer(),
+            text("5"), "7"
+        )
+
+    def test_compare_character_str(self):
+        self._compare_default_roundtrip(
+            String(),
+            "hello",
+        )
+
+    def test_compare_character_text(self):
+        self._compare_default_roundtrip(
+            String(),
+            text("'hello'"),
+        )
+
+    def test_compare_character_str_diff(self):
+        self._compare_default_roundtrip(
+            String(),
+            "hello",
+            "there"
         )
 
-    def test_compare_character_diff(self):
+    def test_compare_character_text_diff(self):
         self._compare_default_roundtrip(
             String(),
-            "'hello'",
-            "'there'"
+            text("'hello'"),
+            text("'there'")
         )
 
     def test_primary_key_skip(self):