From: Dimitris Theodorou Date: Sun, 11 Jan 2015 23:41:59 +0000 (+0100) Subject: Change single-quoting of floats in PostgreSQL compare_server_default X-Git-Tag: rel_0_7_4~3^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=214228a9dd0fa307fc9d0ebfd6a52390cd63f6f7;p=thirdparty%2Fsqlalchemy%2Falembic.git Change single-quoting of floats in PostgreSQL compare_server_default Do not wrap string defaults with single quotes when comparing against columns of type float or numeric. This fixes the crash occuring when the default of a float column is an integer value (e.g., DEFAULT 5), while the Python server_default is a string (e.g., server_default="5.0"). This results in the query used in the comparison to throw a DataError ('SELECT 5 = '5.0'). --- diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 0877c959..4c6e9d79 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -4,7 +4,7 @@ from .. import compat from .base import compiles, alter_table, format_table_name, RenameTable from .impl import DefaultImpl from sqlalchemy.dialects.postgresql import INTEGER, BIGINT -from sqlalchemy import text +from sqlalchemy import text, Float, Numeric import logging log = logging.getLogger(__name__) @@ -35,7 +35,10 @@ class PostgresqlImpl(DefaultImpl): if metadata_column.server_default is not None and \ isinstance(metadata_column.server_default.arg, compat.string_types) and \ - not re.match(r"^'.+'$", rendered_metadata_default): + not re.match(r"^'.+'$", rendered_metadata_default) and \ + not isinstance(inspector_column.type, (Float, Numeric)): + # don't single quote if the column type is float/numeric, + # otherwise a comparison such as SELECT 5 = '5.0' will fail rendered_metadata_default = "'%s'" % rendered_metadata_default return not self.connection.scalar( diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 908eec6d..e70d05a3 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,6 +1,6 @@ from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, \ - String, Interval, Sequence, Numeric, BigInteger + String, Interval, Sequence, Numeric, BigInteger, Float, Numeric from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.engine.reflection import Inspector from alembic.operations import Operations @@ -193,8 +193,11 @@ class PostgresqlDefaultCompareTest(TestBase): def tearDown(self): self.metadata.drop_all() - def _compare_default_roundtrip(self, type_, orig_default, alternate=None): - diff_expected = alternate is not None + def _compare_default_roundtrip( + self, type_, orig_default, alternate=None, diff_expected=None): + diff_expected = diff_expected \ + if diff_expected is not None \ + else alternate is not None if alternate is None: alternate = orig_default @@ -274,6 +277,67 @@ class PostgresqlDefaultCompareTest(TestBase): text("5"), "7" ) + def test_compare_float_str(self): + self._compare_default_roundtrip( + Float(), + "5.2", + ) + + def test_compare_float_text(self): + self._compare_default_roundtrip( + Float(), + text("5.2"), + ) + + def test_compare_float_no_diff1(self): + self._compare_default_roundtrip( + Float(), + text("5.2"), "5.2", + diff_expected=False + ) + + def test_compare_float_no_diff2(self): + self._compare_default_roundtrip( + Float(), + "5.2", text("5.2"), + diff_expected=False + ) + + def test_compare_float_no_diff3(self): + self._compare_default_roundtrip( + Float(), + text("5"), text("5.0"), + diff_expected=False + ) + + def test_compare_float_no_diff4(self): + self._compare_default_roundtrip( + Float(), + "5", "5.0", + diff_expected=False + ) + + def test_compare_float_no_diff5(self): + self._compare_default_roundtrip( + Float(), + text("5"), "5.0", + diff_expected=False + ) + + def test_compare_float_no_diff6(self): + self._compare_default_roundtrip( + Float(), + "5", text("5.0"), + diff_expected=False + ) + + def test_compare_numeric_no_diff(self): + self._compare_default_roundtrip( + Numeric(), + text("5"), "5.0", + diff_expected=False + ) + def test_compare_character_str(self): self._compare_default_roundtrip( String(),