]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Change single-quoting of floats in PostgreSQL compare_server_default
authorDimitris Theodorou <dimitris.theodorou@gmail.com>
Sun, 11 Jan 2015 23:41:59 +0000 (00:41 +0100)
committerDimitris Theodorou <dimitris.theodorou@gmail.com>
Sun, 11 Jan 2015 23:41:59 +0000 (00:41 +0100)
Do not wrap string defaults with single quotes when comparing against
columns of type float or numeric.
This fixes the crash occuring when the default of a float column is
an integer value (e.g., DEFAULT 5), while the Python server_default is
a string (e.g., server_default="5.0"). This results in the query
used in the comparison to throw a DataError ('SELECT 5 = '5.0').

alembic/ddl/postgresql.py
tests/test_postgresql.py

index 0877c959e646ddf328b558a043ecd5f666fefc48..4c6e9d79fd57ff5ef0dd67c7fb8cc08ebd6b1093 100644 (file)
@@ -4,7 +4,7 @@ from .. import compat
 from .base import compiles, alter_table, format_table_name, RenameTable
 from .impl import DefaultImpl
 from sqlalchemy.dialects.postgresql import INTEGER, BIGINT
-from sqlalchemy import text
+from sqlalchemy import text, Float, Numeric
 import logging
 
 log = logging.getLogger(__name__)
@@ -35,7 +35,10 @@ class PostgresqlImpl(DefaultImpl):
         if metadata_column.server_default is not None and \
             isinstance(metadata_column.server_default.arg,
                        compat.string_types) and \
-                not re.match(r"^'.+'$", rendered_metadata_default):
+                not re.match(r"^'.+'$", rendered_metadata_default) and \
+                not isinstance(inspector_column.type, (Float, Numeric)):
+                # don't single quote if the column type is float/numeric,
+                # otherwise a comparison such as SELECT 5 = '5.0' will fail
             rendered_metadata_default = "'%s'" % rendered_metadata_default
 
         return not self.connection.scalar(
index 908eec6db3c66fd896914c8675b96bd35453dd17..e70d05a3b0efc79f12416d90ffc7ac273d1b70ac 100644 (file)
@@ -1,6 +1,6 @@
 
 from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, \
-    String, Interval, Sequence, Numeric, BigInteger
+    String, Interval, Sequence, Numeric, BigInteger, Float, Numeric
 from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.engine.reflection import Inspector
 from alembic.operations import Operations
@@ -193,8 +193,11 @@ class PostgresqlDefaultCompareTest(TestBase):
     def tearDown(self):
         self.metadata.drop_all()
 
-    def _compare_default_roundtrip(self, type_, orig_default, alternate=None):
-        diff_expected = alternate is not None
+    def _compare_default_roundtrip(
+            self, type_, orig_default, alternate=None, diff_expected=None):
+        diff_expected = diff_expected \
+            if diff_expected is not None \
+            else alternate is not None
         if alternate is None:
             alternate = orig_default
 
@@ -274,6 +277,67 @@ class PostgresqlDefaultCompareTest(TestBase):
             text("5"), "7"
         )
 
+    def test_compare_float_str(self):
+        self._compare_default_roundtrip(
+            Float(),
+            "5.2",
+        )
+
+    def test_compare_float_text(self):
+        self._compare_default_roundtrip(
+            Float(),
+            text("5.2"),
+        )
+
+    def test_compare_float_no_diff1(self):
+        self._compare_default_roundtrip(
+            Float(),
+            text("5.2"), "5.2",
+            diff_expected=False
+        )
+
+    def test_compare_float_no_diff2(self):
+        self._compare_default_roundtrip(
+            Float(),
+            "5.2", text("5.2"),
+            diff_expected=False
+        )
+
+    def test_compare_float_no_diff3(self):
+        self._compare_default_roundtrip(
+            Float(),
+            text("5"), text("5.0"),
+            diff_expected=False
+        )
+
+    def test_compare_float_no_diff4(self):
+        self._compare_default_roundtrip(
+            Float(),
+            "5", "5.0",
+            diff_expected=False
+        )
+
+    def test_compare_float_no_diff5(self):
+        self._compare_default_roundtrip(
+            Float(),
+            text("5"), "5.0",
+            diff_expected=False
+        )
+
+    def test_compare_float_no_diff6(self):
+        self._compare_default_roundtrip(
+            Float(),
+            "5", text("5.0"),
+            diff_expected=False
+        )
+
+    def test_compare_numeric_no_diff(self):
+        self._compare_default_roundtrip(
+            Numeric(),
+            text("5"), "5.0",
+            diff_expected=False
+        )
+
     def test_compare_character_str(self):
         self._compare_default_roundtrip(
             String(),