From: Mike Bayer Date: Mon, 21 Oct 2019 21:32:04 +0000 (-0400) Subject: Refactor dialect tests for combinations X-Git-Tag: rel_1_3_11~26^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bf905106bf4c6c4ee7a31a3ffc3da79c54fcca97;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Refactor dialect tests for combinations Dialect tests tend to have a lot of lists of types, SQL constructs etc, convert as many of these to @combinations as possible. This is exposing that we don't have per-combination exclusion rules set up which is making things a little bit cumbersome. Also set up a fixture that does metadata + DDL. Change-Id: Ief820e48c9202982b0b1e181b87862490cd7b0c3 (cherry picked from commit 240d9a60ccdb540543a72d9ff30a6f50d33acc5d) --- diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 0d469c23d7..5e303be00a 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -52,6 +52,7 @@ from .exclusions import skip_if # noqa from .util import adict # noqa from .util import fail # noqa from .util import force_drop_names # noqa +from .util import metadata_fixture # noqa from .util import provide_metadata # noqa from .util import rowset # noqa from .util import run_as_contextmanager # noqa diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index c29586a670..c5323921af 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -719,7 +719,7 @@ class FixtureFunctions(ABC): raise NotImplementedError() @abc.abstractmethod - def fixture(self, fn): + def fixture(self, *arg, **kw): raise NotImplementedError() diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 903b2a5e63..3c47cbce84 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -382,5 +382,5 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): ident = parameters[0] return pytest.param(*parameters[1:], id=ident) - def fixture(self, fn): - return pytest.fixture(fn) + def fixture(self, *arg, **kw): + return pytest.fixture(*arg, **kw) diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 219511ea0e..64738ad2f6 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -211,6 +211,30 @@ def provide_metadata(fn, *args, **kw): self.metadata = prev_meta +def metadata_fixture(ddl="function"): + """Provide MetaData for a pytest fixture.""" + + from . import config + + def decorate(fn): + def run_ddl(self): + from sqlalchemy import schema + + metadata = self.metadata = schema.MetaData() + try: + result = fn(self, metadata) + metadata.create_all(config.db) + # TODO: + # somehow get a per-function dml erase fixture here + yield result + finally: + metadata.drop_all(config.db) + + return config.fixture(scope=ddl)(run_ddl) + + return decorate + + def force_drop_names(*names): """Force the given table names to be dropped after test complete, isolating for foreign key cycles diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 6355f60c38..301562d1c6 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -37,6 +37,7 @@ from sqlalchemy import SmallInteger from sqlalchemy import sql from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy import testing from sqlalchemy import TEXT from sqlalchemy import TIME from sqlalchemy import Time @@ -454,32 +455,32 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): {"param_1": 10}, ) - def test_varchar_raise(self): - for type_ in ( - String, - VARCHAR, - String(), - VARCHAR(), - NVARCHAR(), - Unicode, - Unicode(), - ): - type_ = sqltypes.to_instance(type_) - assert_raises_message( - exc.CompileError, - "VARCHAR requires a length on dialect mysql", - type_.compile, - dialect=mysql.dialect(), - ) + @testing.combinations( + (String,), + (VARCHAR,), + (String(),), + (VARCHAR(),), + (NVARCHAR(),), + (Unicode,), + (Unicode(),), + ) + def test_varchar_raise(self, type_): + type_ = sqltypes.to_instance(type_) + assert_raises_message( + exc.CompileError, + "VARCHAR requires a length on dialect mysql", + type_.compile, + dialect=mysql.dialect(), + ) - t1 = Table("sometable", MetaData(), Column("somecolumn", type_)) - assert_raises_message( - exc.CompileError, - r"\(in table 'sometable', column 'somecolumn'\)\: " - r"(?:N)?VARCHAR requires a length on dialect mysql", - schema.CreateTable(t1).compile, - dialect=mysql.dialect(), - ) + t1 = Table("sometable", MetaData(), Column("somecolumn", type_)) + assert_raises_message( + exc.CompileError, + r"\(in table 'sometable', column 'somecolumn'\)\: " + r"(?:N)?VARCHAR requires a length on dialect mysql", + schema.CreateTable(t1).compile, + dialect=mysql.dialect(), + ) def test_update_limit(self): t = sql.table("t", sql.column("col1"), sql.column("col2")) @@ -513,75 +514,73 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): def test_sysdate(self): self.assert_compile(func.sysdate(), "SYSDATE()") - def test_cast(self): + m = mysql + + @testing.combinations( + (Integer, "CAST(t.col AS SIGNED INTEGER)"), + (INT, "CAST(t.col AS SIGNED INTEGER)"), + (m.MSInteger, "CAST(t.col AS SIGNED INTEGER)"), + (m.MSInteger(unsigned=True), "CAST(t.col AS UNSIGNED INTEGER)"), + (SmallInteger, "CAST(t.col AS SIGNED INTEGER)"), + (m.MSSmallInteger, "CAST(t.col AS SIGNED INTEGER)"), + (m.MSTinyInteger, "CAST(t.col AS SIGNED INTEGER)"), + # 'SIGNED INTEGER' is a bigint, so this is ok. + (m.MSBigInteger, "CAST(t.col AS SIGNED INTEGER)"), + (m.MSBigInteger(unsigned=False), "CAST(t.col AS SIGNED INTEGER)"), + (m.MSBigInteger(unsigned=True), "CAST(t.col AS UNSIGNED INTEGER)"), + # this is kind of sucky. thank you default arguments! + (NUMERIC, "CAST(t.col AS DECIMAL)"), + (DECIMAL, "CAST(t.col AS DECIMAL)"), + (Numeric, "CAST(t.col AS DECIMAL)"), + (m.MSNumeric, "CAST(t.col AS DECIMAL)"), + (m.MSDecimal, "CAST(t.col AS DECIMAL)"), + (TIMESTAMP, "CAST(t.col AS DATETIME)"), + (DATETIME, "CAST(t.col AS DATETIME)"), + (DATE, "CAST(t.col AS DATE)"), + (TIME, "CAST(t.col AS TIME)"), + (DateTime, "CAST(t.col AS DATETIME)"), + (Date, "CAST(t.col AS DATE)"), + (Time, "CAST(t.col AS TIME)"), + (DateTime, "CAST(t.col AS DATETIME)"), + (Date, "CAST(t.col AS DATE)"), + (m.MSTime, "CAST(t.col AS TIME)"), + (m.MSTimeStamp, "CAST(t.col AS DATETIME)"), + (String, "CAST(t.col AS CHAR)"), + (Unicode, "CAST(t.col AS CHAR)"), + (UnicodeText, "CAST(t.col AS CHAR)"), + (VARCHAR, "CAST(t.col AS CHAR)"), + (NCHAR, "CAST(t.col AS CHAR)"), + (CHAR, "CAST(t.col AS CHAR)"), + (m.CHAR(charset="utf8"), "CAST(t.col AS CHAR CHARACTER SET utf8)"), + (CLOB, "CAST(t.col AS CHAR)"), + (TEXT, "CAST(t.col AS CHAR)"), + (m.TEXT(charset="utf8"), "CAST(t.col AS CHAR CHARACTER SET utf8)"), + (String(32), "CAST(t.col AS CHAR(32))"), + (Unicode(32), "CAST(t.col AS CHAR(32))"), + (CHAR(32), "CAST(t.col AS CHAR(32))"), + (m.MSString, "CAST(t.col AS CHAR)"), + (m.MSText, "CAST(t.col AS CHAR)"), + (m.MSTinyText, "CAST(t.col AS CHAR)"), + (m.MSMediumText, "CAST(t.col AS CHAR)"), + (m.MSLongText, "CAST(t.col AS CHAR)"), + (m.MSNChar, "CAST(t.col AS CHAR)"), + (m.MSNVarChar, "CAST(t.col AS CHAR)"), + (LargeBinary, "CAST(t.col AS BINARY)"), + (BLOB, "CAST(t.col AS BINARY)"), + (m.MSBlob, "CAST(t.col AS BINARY)"), + (m.MSBlob(32), "CAST(t.col AS BINARY)"), + (m.MSTinyBlob, "CAST(t.col AS BINARY)"), + (m.MSMediumBlob, "CAST(t.col AS BINARY)"), + (m.MSLongBlob, "CAST(t.col AS BINARY)"), + (m.MSBinary, "CAST(t.col AS BINARY)"), + (m.MSBinary(32), "CAST(t.col AS BINARY)"), + (m.MSVarBinary, "CAST(t.col AS BINARY)"), + (m.MSVarBinary(32), "CAST(t.col AS BINARY)"), + (Interval, "CAST(t.col AS DATETIME)"), + ) + def test_cast(self, type_, expected): t = sql.table("t", sql.column("col")) - m = mysql - - specs = [ - (Integer, "CAST(t.col AS SIGNED INTEGER)"), - (INT, "CAST(t.col AS SIGNED INTEGER)"), - (m.MSInteger, "CAST(t.col AS SIGNED INTEGER)"), - (m.MSInteger(unsigned=True), "CAST(t.col AS UNSIGNED INTEGER)"), - (SmallInteger, "CAST(t.col AS SIGNED INTEGER)"), - (m.MSSmallInteger, "CAST(t.col AS SIGNED INTEGER)"), - (m.MSTinyInteger, "CAST(t.col AS SIGNED INTEGER)"), - # 'SIGNED INTEGER' is a bigint, so this is ok. - (m.MSBigInteger, "CAST(t.col AS SIGNED INTEGER)"), - (m.MSBigInteger(unsigned=False), "CAST(t.col AS SIGNED INTEGER)"), - (m.MSBigInteger(unsigned=True), "CAST(t.col AS UNSIGNED INTEGER)"), - # this is kind of sucky. thank you default arguments! - (NUMERIC, "CAST(t.col AS DECIMAL)"), - (DECIMAL, "CAST(t.col AS DECIMAL)"), - (Numeric, "CAST(t.col AS DECIMAL)"), - (m.MSNumeric, "CAST(t.col AS DECIMAL)"), - (m.MSDecimal, "CAST(t.col AS DECIMAL)"), - (TIMESTAMP, "CAST(t.col AS DATETIME)"), - (DATETIME, "CAST(t.col AS DATETIME)"), - (DATE, "CAST(t.col AS DATE)"), - (TIME, "CAST(t.col AS TIME)"), - (DateTime, "CAST(t.col AS DATETIME)"), - (Date, "CAST(t.col AS DATE)"), - (Time, "CAST(t.col AS TIME)"), - (DateTime, "CAST(t.col AS DATETIME)"), - (Date, "CAST(t.col AS DATE)"), - (m.MSTime, "CAST(t.col AS TIME)"), - (m.MSTimeStamp, "CAST(t.col AS DATETIME)"), - (String, "CAST(t.col AS CHAR)"), - (Unicode, "CAST(t.col AS CHAR)"), - (UnicodeText, "CAST(t.col AS CHAR)"), - (VARCHAR, "CAST(t.col AS CHAR)"), - (NCHAR, "CAST(t.col AS CHAR)"), - (CHAR, "CAST(t.col AS CHAR)"), - (m.CHAR(charset="utf8"), "CAST(t.col AS CHAR CHARACTER SET utf8)"), - (CLOB, "CAST(t.col AS CHAR)"), - (TEXT, "CAST(t.col AS CHAR)"), - (m.TEXT(charset="utf8"), "CAST(t.col AS CHAR CHARACTER SET utf8)"), - (String(32), "CAST(t.col AS CHAR(32))"), - (Unicode(32), "CAST(t.col AS CHAR(32))"), - (CHAR(32), "CAST(t.col AS CHAR(32))"), - (m.MSString, "CAST(t.col AS CHAR)"), - (m.MSText, "CAST(t.col AS CHAR)"), - (m.MSTinyText, "CAST(t.col AS CHAR)"), - (m.MSMediumText, "CAST(t.col AS CHAR)"), - (m.MSLongText, "CAST(t.col AS CHAR)"), - (m.MSNChar, "CAST(t.col AS CHAR)"), - (m.MSNVarChar, "CAST(t.col AS CHAR)"), - (LargeBinary, "CAST(t.col AS BINARY)"), - (BLOB, "CAST(t.col AS BINARY)"), - (m.MSBlob, "CAST(t.col AS BINARY)"), - (m.MSBlob(32), "CAST(t.col AS BINARY)"), - (m.MSTinyBlob, "CAST(t.col AS BINARY)"), - (m.MSMediumBlob, "CAST(t.col AS BINARY)"), - (m.MSLongBlob, "CAST(t.col AS BINARY)"), - (m.MSBinary, "CAST(t.col AS BINARY)"), - (m.MSBinary(32), "CAST(t.col AS BINARY)"), - (m.MSVarBinary, "CAST(t.col AS BINARY)"), - (m.MSVarBinary(32), "CAST(t.col AS BINARY)"), - (Interval, "CAST(t.col AS DATETIME)"), - ] - - for type_, expected in specs: - self.assert_compile(cast(t.c.col, type_), expected) + self.assert_compile(cast(t.c.col, type_), expected) def test_cast_type_decorator(self): class MyInteger(sqltypes.TypeDecorator): @@ -618,33 +617,29 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): "(foo + 5)", ) - def test_unsupported_casts(self): + m = mysql + + @testing.combinations( + (m.MSBit, "t.col"), + (FLOAT, "t.col"), + (Float, "t.col"), + (m.MSFloat, "t.col"), + (m.MSDouble, "t.col"), + (m.MSReal, "t.col"), + (m.MSYear, "t.col"), + (m.MSYear(2), "t.col"), + (Boolean, "t.col"), + (BOOLEAN, "t.col"), + (m.MSEnum, "t.col"), + (m.MSEnum("1", "2"), "t.col"), + (m.MSSet, "t.col"), + (m.MSSet("1", "2"), "t.col"), + ) + def test_unsupported_casts(self, type_, expected): t = sql.table("t", sql.column("col")) - m = mysql - - specs = [ - (m.MSBit, "t.col"), - (FLOAT, "t.col"), - (Float, "t.col"), - (m.MSFloat, "t.col"), - (m.MSDouble, "t.col"), - (m.MSReal, "t.col"), - (m.MSYear, "t.col"), - (m.MSYear(2), "t.col"), - (Boolean, "t.col"), - (BOOLEAN, "t.col"), - (m.MSEnum, "t.col"), - (m.MSEnum("1", "2"), "t.col"), - (m.MSSet, "t.col"), - (m.MSSet("1", "2"), "t.col"), - ] - - for type_, expected in specs: - with expect_warnings( - "Datatype .* does not support CAST on MySQL;" - ): - self.assert_compile(cast(t.c.col, type_), expected) + with expect_warnings("Datatype .* does not support CAST on MySQL;"): + self.assert_compile(cast(t.c.col, type_), expected) def test_no_cast_pre_4(self): self.assert_compile( diff --git a/test/dialect/mysql/test_dialect.py b/test/dialect/mysql/test_dialect.py index 3005e459b9..26c7d38b8c 100644 --- a/test/dialect/mysql/test_dialect.py +++ b/test/dialect/mysql/test_dialect.py @@ -61,31 +61,28 @@ class DialectTest(fixtures.TestBase): }, ) - def test_normal_arguments_mysqldb(self): + @testing.combinations( + ("compress", True), + ("connect_timeout", 30), + ("read_timeout", 30), + ("write_timeout", 30), + ("client_flag", 1234), + ("local_infile", 1234), + ("use_unicode", False), + ("charset", "hello"), + ) + def test_normal_arguments_mysqldb(self, kwarg, value): from sqlalchemy.dialects.mysql import mysqldb dialect = mysqldb.dialect() - self._test_normal_arguments(dialect) - - def _test_normal_arguments(self, dialect): - for kwarg, value in [ - ("compress", True), - ("connect_timeout", 30), - ("read_timeout", 30), - ("write_timeout", 30), - ("client_flag", 1234), - ("local_infile", 1234), - ("use_unicode", False), - ("charset", "hello"), - ]: - connect_args = dialect.create_connect_args( - make_url( - "mysql://scott:tiger@localhost:3306/test" - "?%s=%s" % (kwarg, value) - ) + connect_args = dialect.create_connect_args( + make_url( + "mysql://scott:tiger@localhost:3306/test" + "?%s=%s" % (kwarg, value) ) + ) - eq_(connect_args[1][kwarg], value) + eq_(connect_args[1][kwarg], value) def test_mysqlconnector_buffered_arg(self): from sqlalchemy.dialects.mysql import mysqlconnector @@ -191,57 +188,58 @@ class DialectTest(fixtures.TestBase): class ParseVersionTest(fixtures.TestBase): - def test_mariadb_normalized_version(self): - for expected, raw_version, version, is_mariadb in [ - ((10, 2, 7), "10.2.7-MariaDB", (10, 2, 7, "MariaDB"), True), - ( - (10, 2, 7), - "5.6.15.10.2.7-MariaDB", - (5, 6, 15, 10, 2, 7, "MariaDB"), - True, - ), - ((10, 2, 10), "10.2.10-MariaDB", (10, 2, 10, "MariaDB"), True), - ((5, 7, 20), "5.7.20", (5, 7, 20), False), - ((5, 6, 15), "5.6.15", (5, 6, 15), False), - ( - (10, 2, 6), - "10.2.6.MariaDB.10.2.6+maria~stretch-log", - (10, 2, 6, "MariaDB", 10, 2, "6+maria~stretch", "log"), - True, - ), - ( - (10, 1, 9), - "10.1.9-MariaDBV1.0R050D002-20170809-1522", - (10, 1, 9, "MariaDB", "V1", "0R050D002", 20170809, 1522), - True, - ), - ]: - dialect = mysql.dialect() - eq_(dialect._parse_server_version(raw_version), version) - dialect.server_version_info = version - eq_(dialect._mariadb_normalized_version_info, expected) - assert dialect._is_mariadb is is_mariadb - - def test_mariadb_check_warning(self): - - for expect_, version in [ - (True, (10, 2, 7, "MariaDB")), - (True, (5, 6, 15, 10, 2, 7, "MariaDB")), - (False, (10, 2, 10, "MariaDB")), - (False, (5, 7, 20)), - (False, (5, 6, 15)), - (True, (10, 2, 6, "MariaDB", 10, 2, "6+maria~stretch", "log")), - ]: - dialect = mysql.dialect() - dialect.server_version_info = version - if expect_: - with expect_warnings( - ".*before 10.2.9 has known issues regarding " - "CHECK constraints" - ): - dialect._warn_for_known_db_issues() - else: + @testing.combinations( + ((10, 2, 7), "10.2.7-MariaDB", (10, 2, 7, "MariaDB"), True), + ( + (10, 2, 7), + "5.6.15.10.2.7-MariaDB", + (5, 6, 15, 10, 2, 7, "MariaDB"), + True, + ), + ((10, 2, 10), "10.2.10-MariaDB", (10, 2, 10, "MariaDB"), True), + ((5, 7, 20), "5.7.20", (5, 7, 20), False), + ((5, 6, 15), "5.6.15", (5, 6, 15), False), + ( + (10, 2, 6), + "10.2.6.MariaDB.10.2.6+maria~stretch-log", + (10, 2, 6, "MariaDB", 10, 2, "6+maria~stretch", "log"), + True, + ), + ( + (10, 1, 9), + "10.1.9-MariaDBV1.0R050D002-20170809-1522", + (10, 1, 9, "MariaDB", "V1", "0R050D002", 20170809, 1522), + True, + ), + ) + def test_mariadb_normalized_version( + self, expected, raw_version, version, is_mariadb + ): + dialect = mysql.dialect() + eq_(dialect._parse_server_version(raw_version), version) + dialect.server_version_info = version + eq_(dialect._mariadb_normalized_version_info, expected) + assert dialect._is_mariadb is is_mariadb + + @testing.combinations( + (True, (10, 2, 7, "MariaDB")), + (True, (5, 6, 15, 10, 2, 7, "MariaDB")), + (False, (10, 2, 10, "MariaDB")), + (False, (5, 7, 20)), + (False, (5, 6, 15)), + (True, (10, 2, 6, "MariaDB", 10, 2, "6+maria~stretch", "log")), + ) + def test_mariadb_check_warning(self, expect_, version): + dialect = mysql.dialect() + dialect.server_version_info = version + if expect_: + with expect_warnings( + ".*before 10.2.9 has known issues regarding " + "CHECK constraints" + ): dialect._warn_for_known_db_issues() + else: + dialect._warn_for_known_db_issues() class RemoveUTCTimestampTest(fixtures.TablesTest): diff --git a/test/dialect/mysql/test_types.py b/test/dialect/mysql/test_types.py index 86e4b13a08..ee626b0826 100644 --- a/test/dialect/mysql/test_types.py +++ b/test/dialect/mysql/test_types.py @@ -33,222 +33,443 @@ from sqlalchemy.testing import is_ from sqlalchemy.util import u -class TypesTest( - fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL -): - "Test MySQL column types" - +class TypeCompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = mysql.dialect() - __only_on__ = "mysql" - __backend__ = True - def test_numeric(self): + @testing.combinations( + # column type, args, kwargs, expected ddl + # e.g. Column(Integer(10, unsigned=True)) == + # 'INTEGER(10) UNSIGNED' + (mysql.MSNumeric, [], {}, "NUMERIC"), + (mysql.MSNumeric, [None], {}, "NUMERIC"), + (mysql.MSNumeric, [12], {}, "NUMERIC(12)"), + ( + mysql.MSNumeric, + [12, 4], + {"unsigned": True}, + "NUMERIC(12, 4) UNSIGNED", + ), + ( + mysql.MSNumeric, + [12, 4], + {"zerofill": True}, + "NUMERIC(12, 4) ZEROFILL", + ), + ( + mysql.MSNumeric, + [12, 4], + {"zerofill": True, "unsigned": True}, + "NUMERIC(12, 4) UNSIGNED ZEROFILL", + ), + (mysql.MSDecimal, [], {}, "DECIMAL"), + (mysql.MSDecimal, [None], {}, "DECIMAL"), + (mysql.MSDecimal, [12], {}, "DECIMAL(12)"), + (mysql.MSDecimal, [12, None], {}, "DECIMAL(12)"), + ( + mysql.MSDecimal, + [12, 4], + {"unsigned": True}, + "DECIMAL(12, 4) UNSIGNED", + ), + ( + mysql.MSDecimal, + [12, 4], + {"zerofill": True}, + "DECIMAL(12, 4) ZEROFILL", + ), + ( + mysql.MSDecimal, + [12, 4], + {"zerofill": True, "unsigned": True}, + "DECIMAL(12, 4) UNSIGNED ZEROFILL", + ), + (mysql.MSDouble, [None, None], {}, "DOUBLE"), + ( + mysql.MSDouble, + [12, 4], + {"unsigned": True}, + "DOUBLE(12, 4) UNSIGNED", + ), + ( + mysql.MSDouble, + [12, 4], + {"zerofill": True}, + "DOUBLE(12, 4) ZEROFILL", + ), + ( + mysql.MSDouble, + [12, 4], + {"zerofill": True, "unsigned": True}, + "DOUBLE(12, 4) UNSIGNED ZEROFILL", + ), + (mysql.MSReal, [None, None], {}, "REAL"), + (mysql.MSReal, [12, 4], {"unsigned": True}, "REAL(12, 4) UNSIGNED"), + (mysql.MSReal, [12, 4], {"zerofill": True}, "REAL(12, 4) ZEROFILL"), + ( + mysql.MSReal, + [12, 4], + {"zerofill": True, "unsigned": True}, + "REAL(12, 4) UNSIGNED ZEROFILL", + ), + (mysql.MSFloat, [], {}, "FLOAT"), + (mysql.MSFloat, [None], {}, "FLOAT"), + (mysql.MSFloat, [12], {}, "FLOAT(12)"), + (mysql.MSFloat, [12, 4], {}, "FLOAT(12, 4)"), + (mysql.MSFloat, [12, 4], {"unsigned": True}, "FLOAT(12, 4) UNSIGNED"), + (mysql.MSFloat, [12, 4], {"zerofill": True}, "FLOAT(12, 4) ZEROFILL"), + ( + mysql.MSFloat, + [12, 4], + {"zerofill": True, "unsigned": True}, + "FLOAT(12, 4) UNSIGNED ZEROFILL", + ), + (mysql.MSInteger, [], {}, "INTEGER"), + (mysql.MSInteger, [4], {}, "INTEGER(4)"), + (mysql.MSInteger, [4], {"unsigned": True}, "INTEGER(4) UNSIGNED"), + (mysql.MSInteger, [4], {"zerofill": True}, "INTEGER(4) ZEROFILL"), + ( + mysql.MSInteger, + [4], + {"zerofill": True, "unsigned": True}, + "INTEGER(4) UNSIGNED ZEROFILL", + ), + (mysql.MSBigInteger, [], {}, "BIGINT"), + (mysql.MSBigInteger, [4], {}, "BIGINT(4)"), + (mysql.MSBigInteger, [4], {"unsigned": True}, "BIGINT(4) UNSIGNED"), + (mysql.MSBigInteger, [4], {"zerofill": True}, "BIGINT(4) ZEROFILL"), + ( + mysql.MSBigInteger, + [4], + {"zerofill": True, "unsigned": True}, + "BIGINT(4) UNSIGNED ZEROFILL", + ), + (mysql.MSMediumInteger, [], {}, "MEDIUMINT"), + (mysql.MSMediumInteger, [4], {}, "MEDIUMINT(4)"), + ( + mysql.MSMediumInteger, + [4], + {"unsigned": True}, + "MEDIUMINT(4) UNSIGNED", + ), + ( + mysql.MSMediumInteger, + [4], + {"zerofill": True}, + "MEDIUMINT(4) ZEROFILL", + ), + ( + mysql.MSMediumInteger, + [4], + {"zerofill": True, "unsigned": True}, + "MEDIUMINT(4) UNSIGNED ZEROFILL", + ), + (mysql.MSTinyInteger, [], {}, "TINYINT"), + (mysql.MSTinyInteger, [1], {}, "TINYINT(1)"), + (mysql.MSTinyInteger, [1], {"unsigned": True}, "TINYINT(1) UNSIGNED"), + (mysql.MSTinyInteger, [1], {"zerofill": True}, "TINYINT(1) ZEROFILL"), + ( + mysql.MSTinyInteger, + [1], + {"zerofill": True, "unsigned": True}, + "TINYINT(1) UNSIGNED ZEROFILL", + ), + (mysql.MSSmallInteger, [], {}, "SMALLINT"), + (mysql.MSSmallInteger, [4], {}, "SMALLINT(4)"), + ( + mysql.MSSmallInteger, + [4], + {"unsigned": True}, + "SMALLINT(4) UNSIGNED", + ), + ( + mysql.MSSmallInteger, + [4], + {"zerofill": True}, + "SMALLINT(4) ZEROFILL", + ), + ( + mysql.MSSmallInteger, + [4], + {"zerofill": True, "unsigned": True}, + "SMALLINT(4) UNSIGNED ZEROFILL", + ), + ) + def test_numeric(self, type_, args, kw, res): "Exercise type specification and options for numeric types." - columns = [ - # column type, args, kwargs, expected ddl - # e.g. Column(Integer(10, unsigned=True)) == - # 'INTEGER(10) UNSIGNED' - (mysql.MSNumeric, [], {}, "NUMERIC"), - (mysql.MSNumeric, [None], {}, "NUMERIC"), - (mysql.MSNumeric, [12], {}, "NUMERIC(12)"), - ( - mysql.MSNumeric, - [12, 4], - {"unsigned": True}, - "NUMERIC(12, 4) UNSIGNED", - ), - ( - mysql.MSNumeric, - [12, 4], - {"zerofill": True}, - "NUMERIC(12, 4) ZEROFILL", - ), - ( - mysql.MSNumeric, - [12, 4], - {"zerofill": True, "unsigned": True}, - "NUMERIC(12, 4) UNSIGNED ZEROFILL", - ), - (mysql.MSDecimal, [], {}, "DECIMAL"), - (mysql.MSDecimal, [None], {}, "DECIMAL"), - (mysql.MSDecimal, [12], {}, "DECIMAL(12)"), - (mysql.MSDecimal, [12, None], {}, "DECIMAL(12)"), - ( - mysql.MSDecimal, - [12, 4], - {"unsigned": True}, - "DECIMAL(12, 4) UNSIGNED", - ), - ( - mysql.MSDecimal, - [12, 4], - {"zerofill": True}, - "DECIMAL(12, 4) ZEROFILL", - ), - ( - mysql.MSDecimal, - [12, 4], - {"zerofill": True, "unsigned": True}, - "DECIMAL(12, 4) UNSIGNED ZEROFILL", - ), - (mysql.MSDouble, [None, None], {}, "DOUBLE"), - ( - mysql.MSDouble, - [12, 4], - {"unsigned": True}, - "DOUBLE(12, 4) UNSIGNED", - ), - ( - mysql.MSDouble, - [12, 4], - {"zerofill": True}, - "DOUBLE(12, 4) ZEROFILL", - ), - ( - mysql.MSDouble, - [12, 4], - {"zerofill": True, "unsigned": True}, - "DOUBLE(12, 4) UNSIGNED ZEROFILL", - ), - (mysql.MSReal, [None, None], {}, "REAL"), - ( - mysql.MSReal, - [12, 4], - {"unsigned": True}, - "REAL(12, 4) UNSIGNED", - ), - ( - mysql.MSReal, - [12, 4], - {"zerofill": True}, - "REAL(12, 4) ZEROFILL", - ), - ( - mysql.MSReal, - [12, 4], - {"zerofill": True, "unsigned": True}, - "REAL(12, 4) UNSIGNED ZEROFILL", - ), - (mysql.MSFloat, [], {}, "FLOAT"), - (mysql.MSFloat, [None], {}, "FLOAT"), - (mysql.MSFloat, [12], {}, "FLOAT(12)"), - (mysql.MSFloat, [12, 4], {}, "FLOAT(12, 4)"), - ( - mysql.MSFloat, - [12, 4], - {"unsigned": True}, - "FLOAT(12, 4) UNSIGNED", - ), - ( - mysql.MSFloat, - [12, 4], - {"zerofill": True}, - "FLOAT(12, 4) ZEROFILL", - ), - ( - mysql.MSFloat, - [12, 4], - {"zerofill": True, "unsigned": True}, - "FLOAT(12, 4) UNSIGNED ZEROFILL", - ), - (mysql.MSInteger, [], {}, "INTEGER"), - (mysql.MSInteger, [4], {}, "INTEGER(4)"), - (mysql.MSInteger, [4], {"unsigned": True}, "INTEGER(4) UNSIGNED"), - (mysql.MSInteger, [4], {"zerofill": True}, "INTEGER(4) ZEROFILL"), - ( - mysql.MSInteger, - [4], - {"zerofill": True, "unsigned": True}, - "INTEGER(4) UNSIGNED ZEROFILL", - ), - (mysql.MSBigInteger, [], {}, "BIGINT"), - (mysql.MSBigInteger, [4], {}, "BIGINT(4)"), - ( - mysql.MSBigInteger, - [4], - {"unsigned": True}, - "BIGINT(4) UNSIGNED", - ), - ( - mysql.MSBigInteger, - [4], - {"zerofill": True}, - "BIGINT(4) ZEROFILL", - ), - ( - mysql.MSBigInteger, - [4], - {"zerofill": True, "unsigned": True}, - "BIGINT(4) UNSIGNED ZEROFILL", - ), - (mysql.MSMediumInteger, [], {}, "MEDIUMINT"), - (mysql.MSMediumInteger, [4], {}, "MEDIUMINT(4)"), - ( - mysql.MSMediumInteger, - [4], - {"unsigned": True}, - "MEDIUMINT(4) UNSIGNED", - ), - ( - mysql.MSMediumInteger, - [4], - {"zerofill": True}, - "MEDIUMINT(4) ZEROFILL", - ), - ( - mysql.MSMediumInteger, - [4], - {"zerofill": True, "unsigned": True}, - "MEDIUMINT(4) UNSIGNED ZEROFILL", - ), - (mysql.MSTinyInteger, [], {}, "TINYINT"), - (mysql.MSTinyInteger, [1], {}, "TINYINT(1)"), - ( - mysql.MSTinyInteger, - [1], - {"unsigned": True}, - "TINYINT(1) UNSIGNED", - ), - ( - mysql.MSTinyInteger, - [1], - {"zerofill": True}, - "TINYINT(1) ZEROFILL", - ), - ( - mysql.MSTinyInteger, - [1], - {"zerofill": True, "unsigned": True}, - "TINYINT(1) UNSIGNED ZEROFILL", - ), - (mysql.MSSmallInteger, [], {}, "SMALLINT"), - (mysql.MSSmallInteger, [4], {}, "SMALLINT(4)"), - ( - mysql.MSSmallInteger, - [4], - {"unsigned": True}, - "SMALLINT(4) UNSIGNED", - ), - ( - mysql.MSSmallInteger, - [4], - {"zerofill": True}, - "SMALLINT(4) ZEROFILL", - ), - ( - mysql.MSSmallInteger, - [4], - {"zerofill": True, "unsigned": True}, - "SMALLINT(4) UNSIGNED ZEROFILL", + type_inst = type_(*args, **kw) + self.assert_compile(type_inst, res) + # test that repr() copies out all arguments + self.assert_compile(eval("mysql.%r" % type_inst), res) + + @testing.combinations( + (mysql.MSChar, [1], {}, "CHAR(1)"), + (mysql.NCHAR, [1], {}, "NATIONAL CHAR(1)"), + (mysql.MSChar, [1], {"binary": True}, "CHAR(1) BINARY"), + (mysql.MSChar, [1], {"ascii": True}, "CHAR(1) ASCII"), + (mysql.MSChar, [1], {"unicode": True}, "CHAR(1) UNICODE"), + ( + mysql.MSChar, + [1], + {"ascii": True, "binary": True}, + "CHAR(1) ASCII BINARY", + ), + ( + mysql.MSChar, + [1], + {"unicode": True, "binary": True}, + "CHAR(1) UNICODE BINARY", + ), + (mysql.MSChar, [1], {"charset": "utf8"}, "CHAR(1) CHARACTER SET utf8"), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "binary": True}, + "CHAR(1) CHARACTER SET utf8 BINARY", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "unicode": True}, + "CHAR(1) CHARACTER SET utf8", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "ascii": True}, + "CHAR(1) CHARACTER SET utf8", + ), + ( + mysql.MSChar, + [1], + {"collation": "utf8_bin"}, + "CHAR(1) COLLATE utf8_bin", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "collation": "utf8_bin"}, + "CHAR(1) CHARACTER SET utf8 COLLATE utf8_bin", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "binary": True}, + "CHAR(1) CHARACTER SET utf8 BINARY", + ), + ( + mysql.MSChar, + [1], + {"charset": "utf8", "collation": "utf8_bin", "binary": True}, + "CHAR(1) CHARACTER SET utf8 COLLATE utf8_bin", + ), + (mysql.MSChar, [1], {"national": True}, "NATIONAL CHAR(1)"), + ( + mysql.MSChar, + [1], + {"national": True, "charset": "utf8"}, + "NATIONAL CHAR(1)", + ), + ( + mysql.MSChar, + [1], + {"national": True, "charset": "utf8", "binary": True}, + "NATIONAL CHAR(1) BINARY", + ), + ( + mysql.MSChar, + [1], + {"national": True, "binary": True, "unicode": True}, + "NATIONAL CHAR(1) BINARY", + ), + ( + mysql.MSChar, + [1], + {"national": True, "collation": "utf8_bin"}, + "NATIONAL CHAR(1) COLLATE utf8_bin", + ), + ( + mysql.MSString, + [1], + {"charset": "utf8", "collation": "utf8_bin"}, + "VARCHAR(1) CHARACTER SET utf8 COLLATE utf8_bin", + ), + ( + mysql.MSString, + [1], + {"national": True, "collation": "utf8_bin"}, + "NATIONAL VARCHAR(1) COLLATE utf8_bin", + ), + ( + mysql.MSTinyText, + [], + {"charset": "utf8", "collation": "utf8_bin"}, + "TINYTEXT CHARACTER SET utf8 COLLATE utf8_bin", + ), + ( + mysql.MSMediumText, + [], + {"charset": "utf8", "binary": True}, + "MEDIUMTEXT CHARACTER SET utf8 BINARY", + ), + (mysql.MSLongText, [], {"ascii": True}, "LONGTEXT ASCII"), + ( + mysql.ENUM, + ["foo", "bar"], + {"unicode": True}, + """ENUM('foo','bar') UNICODE""", + ), + (String, [20], {"collation": "utf8"}, "VARCHAR(20) COLLATE utf8"), + ) + @testing.exclude("mysql", "<", (4, 1, 1), "no charset support") + def test_charset(self, type_, args, kw, res): + """Exercise CHARACTER SET and COLLATE-ish options on string types.""" + + type_inst = type_(*args, **kw) + self.assert_compile(type_inst, res) + + @testing.combinations( + (mysql.MSBit(), "BIT"), + (mysql.MSBit(1), "BIT(1)"), + (mysql.MSBit(63), "BIT(63)"), + ) + def test_bit_50(self, type_, expected): + """Exercise BIT types on 5.0+ (not valid for all engine types)""" + + self.assert_compile(type_, expected) + + @testing.combinations( + (BOOLEAN(), "BOOL"), + (Boolean(), "BOOL"), + (mysql.TINYINT(1), "TINYINT(1)"), + (mysql.TINYINT(1, unsigned=True), "TINYINT(1) UNSIGNED"), + ) + def test_boolean_compile(self, type_, expected): + self.assert_compile(type_, expected) + + def test_timestamp_fsp(self): + self.assert_compile(mysql.TIMESTAMP(fsp=5), "TIMESTAMP(5)") + + @testing.combinations( + ([TIMESTAMP], {}, "TIMESTAMP NULL"), + ([mysql.MSTimeStamp], {}, "TIMESTAMP NULL"), + ( + [ + mysql.MSTimeStamp(), + DefaultClause(sql.text("CURRENT_TIMESTAMP")), + ], + {}, + "TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP", + ), + ( + [mysql.MSTimeStamp, DefaultClause(sql.text("CURRENT_TIMESTAMP"))], + {"nullable": False}, + "TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP", + ), + ( + [ + mysql.MSTimeStamp, + DefaultClause(sql.text("'1999-09-09 09:09:09'")), + ], + {"nullable": False}, + "TIMESTAMP NOT NULL DEFAULT '1999-09-09 09:09:09'", + ), + ( + [ + mysql.MSTimeStamp(), + DefaultClause(sql.text("'1999-09-09 09:09:09'")), + ], + {}, + "TIMESTAMP NULL DEFAULT '1999-09-09 09:09:09'", + ), + ( + [ + mysql.MSTimeStamp(), + DefaultClause( + sql.text( + "'1999-09-09 09:09:09' " "ON UPDATE CURRENT_TIMESTAMP" + ) + ), + ], + {}, + "TIMESTAMP NULL DEFAULT '1999-09-09 09:09:09' " + "ON UPDATE CURRENT_TIMESTAMP", + ), + ( + [ + mysql.MSTimeStamp, + DefaultClause( + sql.text( + "'1999-09-09 09:09:09' " "ON UPDATE CURRENT_TIMESTAMP" + ) + ), + ], + {"nullable": False}, + "TIMESTAMP NOT NULL DEFAULT '1999-09-09 09:09:09' " + "ON UPDATE CURRENT_TIMESTAMP", + ), + ( + [ + mysql.MSTimeStamp(), + DefaultClause( + sql.text( + "CURRENT_TIMESTAMP " "ON UPDATE CURRENT_TIMESTAMP" + ) + ), + ], + {}, + "TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP " + "ON UPDATE CURRENT_TIMESTAMP", + ), + ( + [ + mysql.MSTimeStamp, + DefaultClause( + sql.text( + "CURRENT_TIMESTAMP " "ON UPDATE CURRENT_TIMESTAMP" + ) + ), + ], + {"nullable": False}, + "TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP " + "ON UPDATE CURRENT_TIMESTAMP", + ), + ) + def test_timestamp_defaults(self, spec, kw, expected): + """Exercise funky TIMESTAMP default syntax when used in columns.""" + + c = Column("t", *spec, **kw) + Table("t", MetaData(), c) + self.assert_compile(schema.CreateColumn(c), "t %s" % expected) + + def test_datetime_generic(self): + self.assert_compile(mysql.DATETIME(), "DATETIME") + + def test_datetime_fsp(self): + self.assert_compile(mysql.DATETIME(fsp=4), "DATETIME(4)") + + def test_time_generic(self): + """"Exercise TIME.""" + + self.assert_compile(mysql.TIME(), "TIME") + + def test_time_fsp(self): + self.assert_compile(mysql.TIME(fsp=5), "TIME(5)") + + def test_time_result_processor(self): + eq_( + mysql.TIME().result_processor(None, None)( + datetime.timedelta(seconds=35, minutes=517, microseconds=450) ), - ] + datetime.time(8, 37, 35, 450), + ) + + +class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults): - for type_, args, kw, res in columns: - type_inst = type_(*args, **kw) - self.assert_compile(type_inst, res) - # test that repr() copies out all arguments - self.assert_compile(eval("mysql.%r" % type_inst), res) + __dialect__ = mysql.dialect() + __only_on__ = "mysql" + __backend__ = True # fixed in mysql-connector as of 2.0.1, # see http://bugs.mysql.com/bug.php?id=73266 @@ -266,157 +487,18 @@ class TypesTest( mysql.DOUBLE(decimal_return_scale=12, asdecimal=True), ), ) - t.create(testing.db) - testing.db.execute( - t.insert(), - scale_value=45.768392065789, - unscale_value=45.768392065789, - ) - result = testing.db.scalar(select([t.c.scale_value])) - eq_(result, decimal.Decimal("45.768392065789")) - - result = testing.db.scalar(select([t.c.unscale_value])) - eq_(result, decimal.Decimal("45.768392065789")) - - @testing.exclude("mysql", "<", (4, 1, 1), "no charset support") - def test_charset(self): - """Exercise CHARACTER SET and COLLATE-ish options on string types.""" - - columns = [ - (mysql.MSChar, [1], {}, "CHAR(1)"), - (mysql.NCHAR, [1], {}, "NATIONAL CHAR(1)"), - (mysql.MSChar, [1], {"binary": True}, "CHAR(1) BINARY"), - (mysql.MSChar, [1], {"ascii": True}, "CHAR(1) ASCII"), - (mysql.MSChar, [1], {"unicode": True}, "CHAR(1) UNICODE"), - ( - mysql.MSChar, - [1], - {"ascii": True, "binary": True}, - "CHAR(1) ASCII BINARY", - ), - ( - mysql.MSChar, - [1], - {"unicode": True, "binary": True}, - "CHAR(1) UNICODE BINARY", - ), - ( - mysql.MSChar, - [1], - {"charset": "utf8"}, - "CHAR(1) CHARACTER SET utf8", - ), - ( - mysql.MSChar, - [1], - {"charset": "utf8", "binary": True}, - "CHAR(1) CHARACTER SET utf8 BINARY", - ), - ( - mysql.MSChar, - [1], - {"charset": "utf8", "unicode": True}, - "CHAR(1) CHARACTER SET utf8", - ), - ( - mysql.MSChar, - [1], - {"charset": "utf8", "ascii": True}, - "CHAR(1) CHARACTER SET utf8", - ), - ( - mysql.MSChar, - [1], - {"collation": "utf8_bin"}, - "CHAR(1) COLLATE utf8_bin", - ), - ( - mysql.MSChar, - [1], - {"charset": "utf8", "collation": "utf8_bin"}, - "CHAR(1) CHARACTER SET utf8 COLLATE utf8_bin", - ), - ( - mysql.MSChar, - [1], - {"charset": "utf8", "binary": True}, - "CHAR(1) CHARACTER SET utf8 BINARY", - ), - ( - mysql.MSChar, - [1], - {"charset": "utf8", "collation": "utf8_bin", "binary": True}, - "CHAR(1) CHARACTER SET utf8 COLLATE utf8_bin", - ), - (mysql.MSChar, [1], {"national": True}, "NATIONAL CHAR(1)"), - ( - mysql.MSChar, - [1], - {"national": True, "charset": "utf8"}, - "NATIONAL CHAR(1)", - ), - ( - mysql.MSChar, - [1], - {"national": True, "charset": "utf8", "binary": True}, - "NATIONAL CHAR(1) BINARY", - ), - ( - mysql.MSChar, - [1], - {"national": True, "binary": True, "unicode": True}, - "NATIONAL CHAR(1) BINARY", - ), - ( - mysql.MSChar, - [1], - {"national": True, "collation": "utf8_bin"}, - "NATIONAL CHAR(1) COLLATE utf8_bin", - ), - ( - mysql.MSString, - [1], - {"charset": "utf8", "collation": "utf8_bin"}, - "VARCHAR(1) CHARACTER SET utf8 COLLATE utf8_bin", - ), - ( - mysql.MSString, - [1], - {"national": True, "collation": "utf8_bin"}, - "NATIONAL VARCHAR(1) COLLATE utf8_bin", - ), - ( - mysql.MSTinyText, - [], - {"charset": "utf8", "collation": "utf8_bin"}, - "TINYTEXT CHARACTER SET utf8 COLLATE utf8_bin", - ), - ( - mysql.MSMediumText, - [], - {"charset": "utf8", "binary": True}, - "MEDIUMTEXT CHARACTER SET utf8 BINARY", - ), - (mysql.MSLongText, [], {"ascii": True}, "LONGTEXT ASCII"), - ( - mysql.ENUM, - ["foo", "bar"], - {"unicode": True}, - """ENUM('foo','bar') UNICODE""", - ), - (String, [20], {"collation": "utf8"}, "VARCHAR(20) COLLATE utf8"), - ] - - for type_, args, kw, res in columns: - type_inst = type_(*args, **kw) - self.assert_compile(type_inst, res) - # test that repr() copies out all arguments - self.assert_compile( - eval("mysql.%r" % type_inst) - if type_ is not String - else eval("%r" % type_inst), - res, + with testing.db.connect() as conn: + t.create(conn) + conn.execute( + t.insert(), + scale_value=45.768392065789, + unscale_value=45.768392065789, ) + result = conn.scalar(select([t.c.scale_value])) + eq_(result, decimal.Decimal("45.768392065789")) + + result = conn.scalar(select([t.c.unscale_value])) + eq_(result, decimal.Decimal("45.768392065789")) @testing.only_if("mysql") @testing.fails_on("mysql+mysqlconnector", "different unicode behavior") @@ -446,22 +528,11 @@ class TypesTest( testing.db.scalar(select([t.c.data])), util.text_type ) - def test_bit_50(self): - """Exercise BIT types on 5.0+ (not valid for all engine types)""" - - for type_, expected in [ - (mysql.MSBit(), "BIT"), - (mysql.MSBit(1), "BIT(1)"), - (mysql.MSBit(63), "BIT(63)"), - ]: - self.assert_compile(type_, expected) - - @testing.exclude("mysql", "<", (5, 0, 5), "a 5.0+ feature") - @testing.provide_metadata - def test_bit_50_roundtrip(self): + @testing.metadata_fixture(ddl="class") + def bit_table(self, metadata): bit_table = Table( "mysql_bits", - self.metadata, + metadata, Column("b1", mysql.MSBit), Column("b2", mysql.MSBit()), Column("b3", mysql.MSBit(), nullable=False), @@ -471,83 +542,103 @@ class TypesTest( Column("b7", mysql.MSBit(63)), Column("b8", mysql.MSBit(64)), ) - self.metadata.create_all() + return bit_table + + i, j, k, l = 255, 2 ** 32 - 1, 2 ** 63 - 1, 2 ** 64 - 1 + + @testing.combinations( + (([0] * 8), None), + ([None, None, 0, None, None, None, None, None], None), + (([1] * 8), None), + ([sql.text("b'1'")] * 8, [1] * 8), + ([0, 0, 0, 0, i, i, i, i], None), + ([0, 0, 0, 0, 0, j, j, j], None), + ([0, 0, 0, 0, 0, 0, k, k], None), + ([0, 0, 0, 0, 0, 0, 0, l], None), + argnames="store, expected", + ) + def test_bit_50_roundtrip(self, bit_table, store, expected): meta2 = MetaData(testing.db) reflected = Table("mysql_bits", meta2, autoload=True) - for table in bit_table, reflected: + with testing.db.connect() as conn: + expected = expected or store + conn.execute(reflected.insert(store)) + row = conn.execute(reflected.select()).first() + eq_(list(row), expected) + conn.execute(reflected.delete()) + + @testing.combinations( + (([0] * 8), None), + ([None, None, 0, None, None, None, None, None], None), + (([1] * 8), None), + ([sql.text("b'1'")] * 8, [1] * 8), + ([0, 0, 0, 0, i, i, i, i], None), + ([0, 0, 0, 0, 0, j, j, j], None), + ([0, 0, 0, 0, 0, 0, k, k], None), + ([0, 0, 0, 0, 0, 0, 0, l], None), + argnames="store, expected", + ) + def test_bit_50_roundtrip_reflected(self, bit_table, store, expected): + meta2 = MetaData() + bit_table = Table("mysql_bits", meta2, autoload_with=testing.db) - def roundtrip(store, expected=None): - expected = expected or store - table.insert(store).execute() - row = table.select().execute().first() - try: - self.assert_(list(row) == expected) - except Exception: - print("Storing %s" % store) - print("Expected %s" % expected) - print("Found %s" % list(row)) - raise - table.delete().execute().close() - - roundtrip([0] * 8) - roundtrip([None, None, 0, None, None, None, None, None]) - roundtrip([1] * 8) - roundtrip([sql.text("b'1'")] * 8, [1] * 8) - - i = 255 - roundtrip([0, 0, 0, 0, i, i, i, i]) - i = 2 ** 32 - 1 - roundtrip([0, 0, 0, 0, 0, i, i, i]) - i = 2 ** 63 - 1 - roundtrip([0, 0, 0, 0, 0, 0, i, i]) - i = 2 ** 64 - 1 - roundtrip([0, 0, 0, 0, 0, 0, 0, i]) - - def test_boolean(self): - for type_, expected in [ - (BOOLEAN(), "BOOL"), - (Boolean(), "BOOL"), - (mysql.TINYINT(1), "TINYINT(1)"), - (mysql.TINYINT(1, unsigned=True), "TINYINT(1) UNSIGNED"), - ]: - self.assert_compile(type_, expected) + with testing.db.connect() as conn: + expected = expected or store + conn.execute(bit_table.insert(store)) + row = conn.execute(bit_table.select()).first() + eq_(list(row), expected) + conn.execute(bit_table.delete()) - @testing.provide_metadata - def test_boolean_roundtrip(self): + @testing.metadata_fixture(ddl="class") + def boolean_table(self, metadata): bool_table = Table( "mysql_bool", - self.metadata, + metadata, Column("b1", BOOLEAN), Column("b2", Boolean), Column("b3", mysql.MSTinyInteger(1)), Column("b4", mysql.MSTinyInteger(1, unsigned=True)), Column("b5", mysql.MSTinyInteger), ) - self.metadata.create_all() - table = bool_table + return bool_table + + @testing.combinations( + ([None, None, None, None, None], None), + ([True, True, 1, 1, 1], None), + ([False, False, 0, 0, 0], None), + ([True, True, True, True, True], [True, True, 1, 1, 1]), + ([False, False, 0, 0, 0], [False, False, 0, 0, 0]), + argnames="store, expected", + ) + def test_boolean_roundtrip(self, boolean_table, store, expected): + table = boolean_table - def roundtrip(store, expected=None): + with testing.db.connect() as conn: expected = expected or store - table.insert(store).execute() - row = table.select().execute().first() - self.assert_(list(row) == expected) + conn.execute(table.insert(store)) + row = conn.execute(table.select()).first() + eq_(list(row), expected) for i, val in enumerate(expected): if isinstance(val, bool): self.assert_(val is row[i]) - table.delete().execute() - - roundtrip([None, None, None, None, None]) - roundtrip([True, True, 1, 1, 1]) - roundtrip([False, False, 0, 0, 0]) - roundtrip([True, True, True, True, True], [True, True, 1, 1, 1]) - roundtrip([False, False, 0, 0, 0], [False, False, 0, 0, 0]) - + conn.execute(table.delete()) + + @testing.combinations( + ([None, None, None, None, None], None), + ([True, True, 1, 1, 1], [True, True, True, True, 1]), + ([False, False, 0, 0, 0], [False, False, False, False, 0]), + ([True, True, True, True, True], [True, True, True, True, 1]), + ([False, False, 0, 0, 0], [False, False, False, False, 0]), + argnames="store, expected", + ) + def test_boolean_roundtrip_reflected(self, boolean_table, store, expected): meta2 = MetaData(testing.db) table = Table("mysql_bool", meta2, autoload=True) eq_(colspec(table.c.b3), "b3 TINYINT(1)") eq_(colspec(table.c.b4), "b4 TINYINT(1) UNSIGNED") + meta2 = MetaData(testing.db) table = Table( "mysql_bool", @@ -560,129 +651,26 @@ class TypesTest( ) eq_(colspec(table.c.b3), "b3 BOOL") eq_(colspec(table.c.b4), "b4 BOOL") - roundtrip([None, None, None, None, None]) - roundtrip([True, True, 1, 1, 1], [True, True, True, True, 1]) - roundtrip([False, False, 0, 0, 0], [False, False, False, False, 0]) - roundtrip([True, True, True, True, True], [True, True, True, True, 1]) - roundtrip([False, False, 0, 0, 0], [False, False, False, False, 0]) - - def test_timestamp_fsp(self): - self.assert_compile(mysql.TIMESTAMP(fsp=5), "TIMESTAMP(5)") - def test_timestamp_defaults(self): - """Exercise funky TIMESTAMP default syntax when used in columns.""" - - columns = [ - ([TIMESTAMP], {}, "TIMESTAMP NULL"), - ([mysql.MSTimeStamp], {}, "TIMESTAMP NULL"), - ( - [ - mysql.MSTimeStamp(), - DefaultClause(sql.text("CURRENT_TIMESTAMP")), - ], - {}, - "TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP", - ), - ( - [ - mysql.MSTimeStamp, - DefaultClause(sql.text("CURRENT_TIMESTAMP")), - ], - {"nullable": False}, - "TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP", - ), - ( - [ - mysql.MSTimeStamp, - DefaultClause(sql.text("'1999-09-09 09:09:09'")), - ], - {"nullable": False}, - "TIMESTAMP NOT NULL DEFAULT '1999-09-09 09:09:09'", - ), - ( - [ - mysql.MSTimeStamp(), - DefaultClause(sql.text("'1999-09-09 09:09:09'")), - ], - {}, - "TIMESTAMP NULL DEFAULT '1999-09-09 09:09:09'", - ), - ( - [ - mysql.MSTimeStamp(), - DefaultClause( - sql.text( - "'1999-09-09 09:09:09' " - "ON UPDATE CURRENT_TIMESTAMP" - ) - ), - ], - {}, - "TIMESTAMP NULL DEFAULT '1999-09-09 09:09:09' " - "ON UPDATE CURRENT_TIMESTAMP", - ), - ( - [ - mysql.MSTimeStamp, - DefaultClause( - sql.text( - "'1999-09-09 09:09:09' " - "ON UPDATE CURRENT_TIMESTAMP" - ) - ), - ], - {"nullable": False}, - "TIMESTAMP NOT NULL DEFAULT '1999-09-09 09:09:09' " - "ON UPDATE CURRENT_TIMESTAMP", - ), - ( - [ - mysql.MSTimeStamp(), - DefaultClause( - sql.text( - "CURRENT_TIMESTAMP " "ON UPDATE CURRENT_TIMESTAMP" - ) - ), - ], - {}, - "TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP " - "ON UPDATE CURRENT_TIMESTAMP", - ), - ( - [ - mysql.MSTimeStamp, - DefaultClause( - sql.text( - "CURRENT_TIMESTAMP " "ON UPDATE CURRENT_TIMESTAMP" - ) - ), - ], - {"nullable": False}, - "TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP " - "ON UPDATE CURRENT_TIMESTAMP", - ), - ] - for spec, kw, expected in columns: - c = Column("t", *spec, **kw) - Table("t", MetaData(), c) - self.assert_compile(schema.CreateColumn(c), "t %s" % expected) - - def test_timestamp_nullable_plain(self): - self._test_timestamp_nullable(TIMESTAMP) - - def test_timestamp_nullable_typedecorator(self): - class MyTime(TypeDecorator): - impl = TIMESTAMP - - self._test_timestamp_nullable(MyTime()) + with testing.db.connect() as conn: + expected = expected or store + conn.execute(table.insert(store)) + row = conn.execute(table.select()).first() + eq_(list(row), expected) + for i, val in enumerate(expected): + if isinstance(val, bool): + self.assert_(val is row[i]) + conn.execute(table.delete()) - def test_timestamp_nullable_variant(self): - t = String().with_variant(TIMESTAMP, "mysql") - self._test_timestamp_nullable(t) + class MyTime(TypeDecorator): + impl = TIMESTAMP + @testing.combinations( + (TIMESTAMP,), (MyTime(),), (String().with_variant(TIMESTAMP, "mysql"),) + ) @testing.requires.mysql_zero_date @testing.provide_metadata - def _test_timestamp_nullable(self, type_): + def test_timestamp_nullable(self, type_): ts_table = Table( "mysql_timestamp", self.metadata, @@ -724,35 +712,19 @@ class TypesTest( [(now, now), (None, now), (None, now)], ) - def test_datetime_generic(self): - self.assert_compile(mysql.DATETIME(), "DATETIME") - - def test_datetime_fsp(self): - self.assert_compile(mysql.DATETIME(fsp=4), "DATETIME(4)") - - def test_time_generic(self): - """"Exercise TIME.""" - - self.assert_compile(mysql.TIME(), "TIME") - - def test_time_fsp(self): - self.assert_compile(mysql.TIME(fsp=5), "TIME(5)") - - def test_time_result_processor(self): - eq_( - mysql.TIME().result_processor(None, None)( - datetime.timedelta(seconds=35, minutes=517, microseconds=450) - ), - datetime.time(8, 37, 35, 450), - ) - @testing.fails_on("mysql+oursql", "TODO: probable OurSQL bug") @testing.provide_metadata def test_time_roundtrip(self): t = Table("mysql_time", self.metadata, Column("t1", mysql.TIME())) - t.create() - t.insert().values(t1=datetime.time(8, 37, 35)).execute() - eq_(select([t.c.t1]).scalar(), datetime.time(8, 37, 35)) + + with testing.db.connect() as conn: + t.create(conn) + + conn.execute(t.insert().values(t1=datetime.time(8, 37, 35))) + eq_( + conn.execute(select([t.c.t1])).scalar(), + datetime.time(8, 37, 35), + ) @testing.provide_metadata def test_year(self): @@ -773,12 +745,13 @@ class TypesTest( reflected = Table("mysql_year", MetaData(testing.db), autoload=True) for table in year_table, reflected: - table.insert(["1950", "50", None, 1950]).execute() - row = table.select().execute().first() - eq_(list(row), [1950, 2050, None, 1950]) - table.delete().execute() - self.assert_(colspec(table.c.y1).startswith("y1 YEAR")) - eq_(colspec(table.c.y5), "y5 YEAR(4)") + with testing.db.connect() as conn: + conn.execute(table.insert(["1950", "50", None, 1950])) + row = conn.execute(table.select()).first() + eq_(list(row), [1950, 2050, None, 1950]) + conn.execute(table.delete()) + self.assert_(colspec(table.c.y1).startswith("y1 YEAR")) + eq_(colspec(table.c.y5), "y5 YEAR(4)") class JSONTest(fixtures.TestBase): diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index c557fe4128..e664ba50ec 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -76,45 +76,45 @@ class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL): def test_long(self): self.assert_compile(oracle.LONG(), "LONG") - def test_type_adapt(self): + @testing.combinations( + (Date(), cx_oracle._OracleDate), + (oracle.OracleRaw(), cx_oracle._OracleRaw), + (String(), String), + (VARCHAR(), cx_oracle._OracleString), + (DATE(), cx_oracle._OracleDate), + (oracle.DATE(), oracle.DATE), + (String(50), cx_oracle._OracleString), + (Unicode(), cx_oracle._OracleUnicodeStringCHAR), + (Text(), cx_oracle._OracleText), + (UnicodeText(), cx_oracle._OracleUnicodeTextCLOB), + (CHAR(), cx_oracle._OracleChar), + (NCHAR(), cx_oracle._OracleNChar), + (NVARCHAR(), cx_oracle._OracleUnicodeStringNCHAR), + (oracle.RAW(50), cx_oracle._OracleRaw), + ) + def test_type_adapt(self, start, test): dialect = cx_oracle.dialect() - for start, test in [ - (Date(), cx_oracle._OracleDate), - (oracle.OracleRaw(), cx_oracle._OracleRaw), - (String(), String), - (VARCHAR(), cx_oracle._OracleString), - (DATE(), cx_oracle._OracleDate), - (oracle.DATE(), oracle.DATE), - (String(50), cx_oracle._OracleString), - (Unicode(), cx_oracle._OracleUnicodeStringCHAR), - (Text(), cx_oracle._OracleText), - (UnicodeText(), cx_oracle._OracleUnicodeTextCLOB), - (CHAR(), cx_oracle._OracleChar), - (NCHAR(), cx_oracle._OracleNChar), - (NVARCHAR(), cx_oracle._OracleUnicodeStringNCHAR), - (oracle.RAW(50), cx_oracle._OracleRaw), - ]: - assert isinstance( - start.dialect_impl(dialect), test - ), "wanted %r got %r" % (test, start.dialect_impl(dialect)) - - def test_type_adapt_nchar(self): + assert isinstance( + start.dialect_impl(dialect), test + ), "wanted %r got %r" % (test, start.dialect_impl(dialect)) + + @testing.combinations( + (String(), String), + (VARCHAR(), cx_oracle._OracleString), + (String(50), cx_oracle._OracleString), + (Unicode(), cx_oracle._OracleUnicodeStringNCHAR), + (Text(), cx_oracle._OracleText), + (UnicodeText(), cx_oracle._OracleUnicodeTextNCLOB), + (NCHAR(), cx_oracle._OracleNChar), + (NVARCHAR(), cx_oracle._OracleUnicodeStringNCHAR), + ) + def test_type_adapt_nchar(self, start, test): dialect = cx_oracle.dialect(use_nchar_for_unicode=True) - for start, test in [ - (String(), String), - (VARCHAR(), cx_oracle._OracleString), - (String(50), cx_oracle._OracleString), - (Unicode(), cx_oracle._OracleUnicodeStringNCHAR), - (Text(), cx_oracle._OracleText), - (UnicodeText(), cx_oracle._OracleUnicodeTextNCLOB), - (NCHAR(), cx_oracle._OracleNChar), - (NVARCHAR(), cx_oracle._OracleUnicodeStringNCHAR), - ]: - assert isinstance( - start.dialect_impl(dialect), test - ), "wanted %r got %r" % (test, start.dialect_impl(dialect)) + assert isinstance( + start.dialect_impl(dialect), test + ), "wanted %r got %r" % (test, start.dialect_impl(dialect)) def test_raw_compile(self): self.assert_compile(oracle.RAW(), "RAW") @@ -130,53 +130,53 @@ class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(NVARCHAR(50), "NVARCHAR2(50)") self.assert_compile(CHAR(50), "CHAR(50)") - def test_varchar_types(self): + @testing.combinations( + (String(50), "VARCHAR2(50 CHAR)"), + (Unicode(50), "VARCHAR2(50 CHAR)"), + (NVARCHAR(50), "NVARCHAR2(50)"), + (VARCHAR(50), "VARCHAR(50 CHAR)"), + (oracle.NVARCHAR2(50), "NVARCHAR2(50)"), + (oracle.VARCHAR2(50), "VARCHAR2(50 CHAR)"), + (String(), "VARCHAR2"), + (Unicode(), "VARCHAR2"), + (NVARCHAR(), "NVARCHAR2"), + (VARCHAR(), "VARCHAR"), + (oracle.NVARCHAR2(), "NVARCHAR2"), + (oracle.VARCHAR2(), "VARCHAR2"), + ) + def test_varchar_types(self, typ, exp): dialect = oracle.dialect() - for typ, exp in [ - (String(50), "VARCHAR2(50 CHAR)"), - (Unicode(50), "VARCHAR2(50 CHAR)"), - (NVARCHAR(50), "NVARCHAR2(50)"), - (VARCHAR(50), "VARCHAR(50 CHAR)"), - (oracle.NVARCHAR2(50), "NVARCHAR2(50)"), - (oracle.VARCHAR2(50), "VARCHAR2(50 CHAR)"), - (String(), "VARCHAR2"), - (Unicode(), "VARCHAR2"), - (NVARCHAR(), "NVARCHAR2"), - (VARCHAR(), "VARCHAR"), - (oracle.NVARCHAR2(), "NVARCHAR2"), - (oracle.VARCHAR2(), "VARCHAR2"), - ]: - self.assert_compile(typ, exp, dialect=dialect) - - def test_varchar_use_nchar_types(self): + self.assert_compile(typ, exp, dialect=dialect) + + @testing.combinations( + (String(50), "VARCHAR2(50 CHAR)"), + (Unicode(50), "NVARCHAR2(50)"), + (NVARCHAR(50), "NVARCHAR2(50)"), + (VARCHAR(50), "VARCHAR(50 CHAR)"), + (oracle.NVARCHAR2(50), "NVARCHAR2(50)"), + (oracle.VARCHAR2(50), "VARCHAR2(50 CHAR)"), + (String(), "VARCHAR2"), + (Unicode(), "NVARCHAR2"), + (NVARCHAR(), "NVARCHAR2"), + (VARCHAR(), "VARCHAR"), + (oracle.NVARCHAR2(), "NVARCHAR2"), + (oracle.VARCHAR2(), "VARCHAR2"), + ) + def test_varchar_use_nchar_types(self, typ, exp): dialect = oracle.dialect(use_nchar_for_unicode=True) - for typ, exp in [ - (String(50), "VARCHAR2(50 CHAR)"), - (Unicode(50), "NVARCHAR2(50)"), - (NVARCHAR(50), "NVARCHAR2(50)"), - (VARCHAR(50), "VARCHAR(50 CHAR)"), - (oracle.NVARCHAR2(50), "NVARCHAR2(50)"), - (oracle.VARCHAR2(50), "VARCHAR2(50 CHAR)"), - (String(), "VARCHAR2"), - (Unicode(), "NVARCHAR2"), - (NVARCHAR(), "NVARCHAR2"), - (VARCHAR(), "VARCHAR"), - (oracle.NVARCHAR2(), "NVARCHAR2"), - (oracle.VARCHAR2(), "VARCHAR2"), - ]: - self.assert_compile(typ, exp, dialect=dialect) + self.assert_compile(typ, exp, dialect=dialect) - def test_interval(self): - for type_, expected in [ - (oracle.INTERVAL(), "INTERVAL DAY TO SECOND"), - (oracle.INTERVAL(day_precision=3), "INTERVAL DAY(3) TO SECOND"), - (oracle.INTERVAL(second_precision=5), "INTERVAL DAY TO SECOND(5)"), - ( - oracle.INTERVAL(day_precision=2, second_precision=5), - "INTERVAL DAY(2) TO SECOND(5)", - ), - ]: - self.assert_compile(type_, expected) + @testing.combinations( + (oracle.INTERVAL(), "INTERVAL DAY TO SECOND"), + (oracle.INTERVAL(day_precision=3), "INTERVAL DAY(3) TO SECOND"), + (oracle.INTERVAL(second_precision=5), "INTERVAL DAY TO SECOND(5)"), + ( + oracle.INTERVAL(day_precision=2, second_precision=5), + "INTERVAL DAY(2) TO SECOND(5)", + ), + ) + def test_interval(self, type_, expected): + self.assert_compile(type_, expected) class TypesTest(fixtures.TestBase): @@ -184,14 +184,9 @@ class TypesTest(fixtures.TestBase): __dialect__ = oracle.OracleDialect() __backend__ = True - def test_fixed_char(self): - self._test_fixed_char(CHAR) - - def test_fixed_nchar(self): - self._test_fixed_char(NCHAR) - + @testing.combinations((CHAR,), (NCHAR,)) @testing.provide_metadata - def _test_fixed_char(self, char_type): + def test_fixed_char(self, char_type): m = self.metadata t = Table( "t1", @@ -682,52 +677,46 @@ class TypesTest(fixtures.TestBase): value = testing.db.scalar("SELECT 5.66 FROM DUAL") assert isinstance(value, decimal.Decimal) + @testing.combinations( + ( + "Max 32-bit Number", + "SELECT CAST(2147483647 AS NUMBER(19,0)) FROM dual", + ), + ( + "Min 32-bit Number", + "SELECT CAST(-2147483648 AS NUMBER(19,0)) FROM dual", + ), + ( + "32-bit Integer Overflow", + "SELECT CAST(2147483648 AS NUMBER(19,0)) FROM dual", + ), + ( + "32-bit Integer Underflow", + "SELECT CAST(-2147483649 AS NUMBER(19,0)) FROM dual", + ), + ( + "Max Number with Precision 19", + "SELECT CAST(9999999999999999999 AS NUMBER(19,0)) FROM dual", + ), + ( + "Min Number with Precision 19", + "SELECT CAST(-9999999999999999999 AS NUMBER(19,0)) FROM dual", + ), + ) @testing.only_on("oracle+cx_oracle", "cx_oracle-specific feature") - def test_raw_numerics(self): - query_cases = [ - ( - "Max 32-bit Number", - "SELECT CAST(2147483647 AS NUMBER(19,0)) FROM dual", - ), - ( - "Min 32-bit Number", - "SELECT CAST(-2147483648 AS NUMBER(19,0)) FROM dual", - ), - ( - "32-bit Integer Overflow", - "SELECT CAST(2147483648 AS NUMBER(19,0)) FROM dual", - ), - ( - "32-bit Integer Underflow", - "SELECT CAST(-2147483649 AS NUMBER(19,0)) FROM dual", - ), - ( - "Max Number with Precision 19", - "SELECT CAST(9999999999999999999 AS NUMBER(19,0)) FROM dual", - ), - ( - "Min Number with Precision 19", - "SELECT CAST(-9999999999999999999 AS NUMBER(19,0)) FROM dual", - ), - ] - + def test_raw_numerics(self, title, stmt): with testing.db.connect() as conn: - for title, stmt in query_cases: - # get a brand new connection that definitely is not - # in the pool to avoid any outputtypehandlers - cx_oracle_raw = testing.db.pool._creator() - cursor = cx_oracle_raw.cursor() - cursor.execute(stmt) - cx_oracle_result = cursor.fetchone()[0] - cursor.close() - - sqla_result = conn.scalar(stmt) - - print( - "%s cx_oracle=%s sqlalchemy=%s" - % (title, cx_oracle_result, sqla_result) - ) - eq_(sqla_result, cx_oracle_result) + # get a brand new connection that definitely is not + # in the pool to avoid any outputtypehandlers + cx_oracle_raw = testing.db.pool._creator() + cursor = cx_oracle_raw.cursor() + cursor.execute(stmt) + cx_oracle_result = cursor.fetchone()[0] + cursor.close() + + sqla_result = conn.scalar(stmt) + + eq_(sqla_result, cx_oracle_result) @testing.only_on("oracle+cx_oracle", "cx_oracle-specific feature") @testing.fails_if( @@ -1111,10 +1100,32 @@ class SetInputSizesTest(fixtures.TestBase): __only_on__ = "oracle+cx_oracle" __backend__ = True + @testing.combinations( + (SmallInteger, 25, int, False), + (Integer, 25, int, False), + (Numeric(10, 8), decimal.Decimal("25.34534"), None, False), + (Float(15), 25.34534, None, False), + (oracle.BINARY_DOUBLE, 25.34534, "NATIVE_FLOAT", False), + (oracle.BINARY_FLOAT, 25.34534, "NATIVE_FLOAT", False), + (oracle.DOUBLE_PRECISION, 25.34534, None, False), + (Unicode(30), u("test"), "NCHAR", True), + (UnicodeText(), u("test"), "NCLOB", True), + (Unicode(30), u("test"), None, False), + (UnicodeText(), u("test"), "CLOB", False), + (String(30), "test", None, False), + (CHAR(30), "test", "FIXED_CHAR", False), + (NCHAR(30), u("test"), "FIXED_NCHAR", False), + (oracle.LONG(), "test", None, False), + ) @testing.provide_metadata - def _test_setinputsizes( - self, datatype, value, sis_value, set_nchar_flag=False + def test_setinputsizes( + self, datatype, value, sis_value_text, set_nchar_flag ): + if isinstance(sis_value_text, str): + sis_value = getattr(testing.db.dialect.dbapi, sis_value_text) + else: + sis_value = sis_value_text + class TestTypeDec(TypeDecorator): impl = NullType() @@ -1176,77 +1187,6 @@ class SetInputSizesTest(fixtures.TestBase): [mock.call.setinputsizes()], ) - def test_smallint_setinputsizes(self): - self._test_setinputsizes(SmallInteger, 25, int) - - def test_int_setinputsizes(self): - self._test_setinputsizes(Integer, 25, int) - - def test_numeric_setinputsizes(self): - self._test_setinputsizes( - Numeric(10, 8), decimal.Decimal("25.34534"), None - ) - - def test_float_setinputsizes(self): - self._test_setinputsizes(Float(15), 25.34534, None) - - def test_binary_double_setinputsizes(self): - self._test_setinputsizes( - oracle.BINARY_DOUBLE, - 25.34534, - testing.db.dialect.dbapi.NATIVE_FLOAT, - ) - - def test_binary_float_setinputsizes(self): - self._test_setinputsizes( - oracle.BINARY_FLOAT, - 25.34534, - testing.db.dialect.dbapi.NATIVE_FLOAT, - ) - - def test_double_precision_setinputsizes(self): - self._test_setinputsizes(oracle.DOUBLE_PRECISION, 25.34534, None) - - def test_unicode_nchar_mode(self): - self._test_setinputsizes( - Unicode(30), - u("test"), - testing.db.dialect.dbapi.NCHAR, - set_nchar_flag=True, - ) - - def test_unicodetext_nchar_mode(self): - self._test_setinputsizes( - UnicodeText(), - u("test"), - testing.db.dialect.dbapi.NCLOB, - set_nchar_flag=True, - ) - - def test_unicode(self): - self._test_setinputsizes(Unicode(30), u("test"), None) - - def test_unicodetext(self): - self._test_setinputsizes( - UnicodeText(), u("test"), testing.db.dialect.dbapi.CLOB - ) - - def test_string(self): - self._test_setinputsizes(String(30), "test", None) - - def test_char(self): - self._test_setinputsizes( - CHAR(30), "test", testing.db.dialect.dbapi.FIXED_CHAR - ) - - def test_nchar(self): - self._test_setinputsizes( - NCHAR(30), u("test"), testing.db.dialect.dbapi.FIXED_NCHAR - ) - - def test_long(self): - self._test_setinputsizes(oracle.LONG(), "test", None) - def test_event_no_native_float(self): def _remove_type(inputsizes, cursor, statement, parameters, context): for param, dbapitype in list(inputsizes.items()): @@ -1255,6 +1195,6 @@ class SetInputSizesTest(fixtures.TestBase): event.listen(testing.db, "do_setinputsizes", _remove_type) try: - self._test_setinputsizes(oracle.BINARY_FLOAT, 25.34534, None) + self.test_setinputsizes(oracle.BINARY_FLOAT, 25.34534, None, False) finally: event.remove(testing.db, "do_setinputsizes", _remove_type) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 557b916222..1eb8677bfd 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1,6 +1,7 @@ # coding: utf-8 import datetime import decimal +import uuid import sqlalchemy as sa from sqlalchemy import any_ @@ -32,6 +33,7 @@ from sqlalchemy import Unicode from sqlalchemy import util from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import array +from sqlalchemy.dialects.postgresql import base from sqlalchemy.dialects.postgresql import DATERANGE from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import hstore @@ -885,37 +887,35 @@ class TimezoneTest(fixtures.TestBase): assert row[0] >= somedate -class TimePrecisionTest(fixtures.TestBase, AssertsCompiledSQL): +class TimePrecisionCompileTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = postgresql.dialect() + + @testing.combinations( + (postgresql.TIME(), "TIME WITHOUT TIME ZONE"), + (postgresql.TIME(precision=5), "TIME(5) WITHOUT TIME ZONE"), + ( + postgresql.TIME(timezone=True, precision=5), + "TIME(5) WITH TIME ZONE", + ), + (postgresql.TIMESTAMP(), "TIMESTAMP WITHOUT TIME ZONE"), + (postgresql.TIMESTAMP(precision=5), "TIMESTAMP(5) WITHOUT TIME ZONE"), + ( + postgresql.TIMESTAMP(timezone=True, precision=5), + "TIMESTAMP(5) WITH TIME ZONE", + ), + (postgresql.TIME(precision=0), "TIME(0) WITHOUT TIME ZONE"), + (postgresql.TIMESTAMP(precision=0), "TIMESTAMP(0) WITHOUT TIME ZONE"), + ) + def test_compile(self, type_, expected): + self.assert_compile(type_, expected) + + +class TimePrecisionTest(fixtures.TestBase): __dialect__ = postgresql.dialect() __prefer__ = "postgresql" __backend__ = True - def test_compile(self): - for type_, expected in [ - (postgresql.TIME(), "TIME WITHOUT TIME ZONE"), - (postgresql.TIME(precision=5), "TIME(5) WITHOUT TIME ZONE"), - ( - postgresql.TIME(timezone=True, precision=5), - "TIME(5) WITH TIME ZONE", - ), - (postgresql.TIMESTAMP(), "TIMESTAMP WITHOUT TIME ZONE"), - ( - postgresql.TIMESTAMP(precision=5), - "TIMESTAMP(5) WITHOUT TIME ZONE", - ), - ( - postgresql.TIMESTAMP(timezone=True, precision=5), - "TIMESTAMP(5) WITH TIME ZONE", - ), - (postgresql.TIME(precision=0), "TIME(0) WITHOUT TIME ZONE"), - ( - postgresql.TIMESTAMP(precision=0), - "TIMESTAMP(0) WITHOUT TIME ZONE", - ), - ]: - self.assert_compile(type_, expected) - @testing.only_on("postgresql", "DB specific feature") @testing.provide_metadata def test_reflection(self): @@ -1608,7 +1608,8 @@ class PGArrayRoundTripTest( ): ARRAY = postgresql.ARRAY - def _test_undim_array_contains_typed_exec(self, struct): + @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),)) + def test_undim_array_contains_typed_exec(self, struct): arrtable = self.tables.arrtable self._fixture_456(arrtable) eq_( @@ -1620,18 +1621,8 @@ class PGArrayRoundTripTest( [4, 5, 6], ) - def test_undim_array_contains_set_exec(self): - self._test_undim_array_contains_typed_exec(set) - - def test_undim_array_contains_list_exec(self): - self._test_undim_array_contains_typed_exec(list) - - def test_undim_array_contains_generator_exec(self): - self._test_undim_array_contains_typed_exec( - lambda elem: (x for x in elem) - ) - - def _test_dim_array_contains_typed_exec(self, struct): + @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),)) + def test_dim_array_contains_typed_exec(self, struct): dim_arrtable = self.tables.dim_arrtable self._fixture_456(dim_arrtable) eq_( @@ -1643,17 +1634,6 @@ class PGArrayRoundTripTest( [4, 5, 6], ) - def test_dim_array_contains_set_exec(self): - self._test_dim_array_contains_typed_exec(set) - - def test_dim_array_contains_list_exec(self): - self._test_dim_array_contains_typed_exec(list) - - def test_dim_array_contains_generator_exec(self): - self._test_dim_array_contains_typed_exec( - lambda elem: (x for x in elem) - ) - def test_array_contained_by_exec(self): arrtable = self.tables.arrtable with testing.db.connect() as conn: @@ -1697,7 +1677,28 @@ class HashableFlagORMTest(fixtures.TestBase): __only_on__ = "postgresql" - def _test(self, type_, data): + @testing.combinations( + ( + "ARRAY", + postgresql.ARRAY(Text()), + [["a", "b", "c"], ["d", "e", "f"]], + ), + ( + "JSON", + postgresql.JSON(), + [ + {"a": "1", "b": "2", "c": "3"}, + { + "d": "4", + "e": {"e1": "5", "e2": "6"}, + "f": {"f1": [9, 10, 11]}, + }, + ], + ), + id_="iaa", + ) + @testing.provide_metadata + def test_hashable_flag(self, type_, data): Base = declarative_base(metadata=self.metadata) class A(Base): @@ -1718,38 +1719,16 @@ class HashableFlagORMTest(fixtures.TestBase): list(enumerate(data, 1)), ) - @testing.provide_metadata - def test_array(self): - self._test( - postgresql.ARRAY(Text()), [["a", "b", "c"], ["d", "e", "f"]] - ) - @testing.requires.hstore - @testing.provide_metadata def test_hstore(self): - self._test( + self.test_hashable_flag( postgresql.HSTORE(), [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}], ) - @testing.provide_metadata - def test_json(self): - self._test( - postgresql.JSON(), - [ - {"a": "1", "b": "2", "c": "3"}, - { - "d": "4", - "e": {"e1": "5", "e2": "6"}, - "f": {"f1": [9, 10, 11]}, - }, - ], - ) - @testing.requires.postgresql_jsonb - @testing.provide_metadata def test_jsonb(self): - self._test( + self.test_hashable_flag( postgresql.JSONB(), [ {"a": "1", "b": "2", "c": "3"}, @@ -1795,17 +1774,28 @@ class TimestampTest(fixtures.TestBase, AssertsExecutionResults): assert isinstance(expr.type, postgresql.INTERVAL) -class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): +class SpecialTypesCompileTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = "postgresql" + + @testing.combinations( + (postgresql.BIT(), "BIT(1)"), + (postgresql.BIT(5), "BIT(5)"), + (postgresql.BIT(varying=True), "BIT VARYING"), + (postgresql.BIT(5, varying=True), "BIT VARYING(5)"), + ) + def test_bit_compile(self, type_, expected): + self.assert_compile(type_, expected) + + +class SpecialTypesTest(fixtures.TablesTest, ComparesTables): """test DDL and reflection of PG-specific types """ __only_on__ = ("postgresql >= 8.3.0",) __backend__ = True - @classmethod - def setup_class(cls): - global metadata, table - metadata = MetaData(testing.db) + @testing.metadata_fixture() + def special_types_table(self, metadata): # create these types so that we can issue # special SQL92 INTERVAL syntax @@ -1835,36 +1825,22 @@ class SpecialTypesTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): Column("tsvector_document", postgresql.TSVECTOR), ) - metadata.create_all() + return table + def test_reflection(self, special_types_table): # cheat so that the "strict type check" # works - table.c.year_interval.type = postgresql.INTERVAL() - table.c.month_interval.type = postgresql.INTERVAL() - - @classmethod - def teardown_class(cls): - metadata.drop_all() + special_types_table.c.year_interval.type = postgresql.INTERVAL() + special_types_table.c.month_interval.type = postgresql.INTERVAL() - def test_reflection(self): m = MetaData(testing.db) t = Table("sometable", m, autoload=True) - self.assert_tables_equal(table, t, strict_types=True) + self.assert_tables_equal(special_types_table, t, strict_types=True) assert t.c.plain_interval.type.precision is None assert t.c.precision_interval.type.precision == 3 assert t.c.bitstring.type.length == 4 - def test_bit_compile(self): - pairs = [ - (postgresql.BIT(), "BIT(1)"), - (postgresql.BIT(5), "BIT(5)"), - (postgresql.BIT(varying=True), "BIT VARYING"), - (postgresql.BIT(5, varying=True), "BIT VARYING(5)"), - ] - for type_, expected in pairs: - self.assert_compile(type_, expected) - @testing.provide_metadata def test_tsvector_round_trip(self): t = Table("t1", self.metadata, Column("data", postgresql.TSVECTOR)) @@ -1910,86 +1886,51 @@ class UUIDTest(fixtures.TestBase): __only_on__ = "postgresql >= 8.3" __backend__ = True - @testing.fails_on( - "postgresql+zxjdbc", - 'column "data" is of type uuid but expression ' - "is of type character varying", - ) - def test_uuid_string(self): - import uuid - - self._test_round_trip( - Table( - "utable", - MetaData(), - Column("data", postgresql.UUID(as_uuid=False)), - ), + @testing.combinations( + ( + "not_as_uuid", + postgresql.UUID(as_uuid=False), str(uuid.uuid4()), str(uuid.uuid4()), - ) - - @testing.fails_on( - "postgresql+zxjdbc", - 'column "data" is of type uuid but expression is ' - "of type character varying", + ), + ("as_uuid", postgresql.UUID(as_uuid=True), uuid.uuid4(), uuid.uuid4()), + id_="iaaa", ) - def test_uuid_uuid(self): - import uuid - - self._test_round_trip( - Table( - "utable", - MetaData(), - Column("data", postgresql.UUID(as_uuid=True)), - ), - uuid.uuid4(), - uuid.uuid4(), - ) + def test_round_trip(self, datatype, value1, value2): - @testing.fails_on( - "postgresql+zxjdbc", - 'column "data" is of type uuid[] but ' - "expression is of type character varying", - ) - @testing.fails_on("postgresql+pg8000", "No support for UUID with ARRAY") - def test_uuid_array(self): - import uuid - - self._test_round_trip( - Table( - "utable", - MetaData(), - Column( - "data", postgresql.ARRAY(postgresql.UUID(as_uuid=True)) - ), - ), + utable = Table("utable", MetaData(), Column("data", datatype)) + + with testing.db.connect() as conn: + conn.begin() + utable.create(conn) + conn.execute(utable.insert(), {"data": value1}) + conn.execute(utable.insert(), {"data": value2}) + r = conn.execute( + select([utable.c.data]).where(utable.c.data != value1) + ) + eq_(r.fetchone()[0], value2) + eq_(r.fetchone(), None) + + @testing.combinations( + ( + "as_uuid", + postgresql.ARRAY(postgresql.UUID(as_uuid=True)), [uuid.uuid4(), uuid.uuid4()], [uuid.uuid4(), uuid.uuid4()], - ) - - @testing.fails_on( - "postgresql+zxjdbc", - 'column "data" is of type uuid[] but ' - "expression is of type character varying", - ) - @testing.fails_on("postgresql+pg8000", "No support for UUID with ARRAY") - def test_uuid_string_array(self): - import uuid - - self._test_round_trip( - Table( - "utable", - MetaData(), - Column( - "data", postgresql.ARRAY(postgresql.UUID(as_uuid=False)) - ), - ), + ), + ( + "not_as_uuid", + postgresql.ARRAY(postgresql.UUID(as_uuid=False)), [str(uuid.uuid4()), str(uuid.uuid4())], [str(uuid.uuid4()), str(uuid.uuid4())], - ) + ), + id_="iaaa", + ) + @testing.fails_on("postgresql+pg8000", "No support for UUID with ARRAY") + def test_uuid_array(self, datatype, value1, value2): + self.test_round_trip(datatype, value1, value2) def test_no_uuid_available(self): - from sqlalchemy.dialects.postgresql import base uuid_type = base._python_UUID base._python_UUID = None @@ -1998,26 +1939,6 @@ class UUIDTest(fixtures.TestBase): finally: base._python_UUID = uuid_type - def setup(self): - self.conn = testing.db.connect() - self.conn.begin() - - def teardown(self): - self.conn.close() - - def _test_round_trip(self, utable, value1, value2, exp_value2=None): - utable.create(self.conn) - self.conn.execute(utable.insert(), {"data": value1}) - self.conn.execute(utable.insert(), {"data": value2}) - r = self.conn.execute( - select([utable.c.data]).where(utable.c.data != value1) - ) - if exp_value2: - eq_(r.fetchone()[0], exp_value2) - else: - eq_(r.fetchone()[0], value2) - eq_(r.fetchone(), None) - class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): __dialect__ = "postgresql" @@ -2040,13 +1961,6 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): "WHERE %s" % expected, ) - def _test_cols(self, colclause, expected, from_=True): - stmt = select([colclause]) - self.assert_compile( - stmt, - ("SELECT %s" + (" FROM test_table" if from_ else "")) % expected, - ) - def test_bind_serialize_default(self): dialect = postgresql.dialect() @@ -2184,60 +2098,48 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): "(test_table.hash -> %(hash_1)s) IS NULL", ) - def test_cols_get(self): - self._test_cols( - self.hashcol["foo"], + @testing.combinations( + ( + lambda self: self.hashcol["foo"], "test_table.hash -> %(hash_1)s AS anon_1", True, - ) - - def test_cols_delete_single_key(self): - self._test_cols( - self.hashcol.delete("foo"), + ), + ( + lambda self: self.hashcol.delete("foo"), "delete(test_table.hash, %(delete_2)s) AS delete_1", True, - ) - - def test_cols_delete_array_of_keys(self): - self._test_cols( - self.hashcol.delete(postgresql.array(["foo", "bar"])), + ), + ( + lambda self: self.hashcol.delete(postgresql.array(["foo", "bar"])), ( "delete(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) " "AS delete_1" ), True, - ) - - def test_cols_delete_matching_pairs(self): - self._test_cols( - self.hashcol.delete(hstore("1", "2")), + ), + ( + lambda self: self.hashcol.delete(hstore("1", "2")), ( "delete(test_table.hash, hstore(%(hstore_1)s, %(hstore_2)s)) " "AS delete_1" ), True, - ) - - def test_cols_slice(self): - self._test_cols( - self.hashcol.slice(postgresql.array(["1", "2"])), + ), + ( + lambda self: self.hashcol.slice(postgresql.array(["1", "2"])), ( "slice(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) " "AS slice_1" ), True, - ) - - def test_cols_hstore_pair_text(self): - self._test_cols( - hstore("foo", "3")["foo"], + ), + ( + lambda self: hstore("foo", "3")["foo"], "hstore(%(hstore_1)s, %(hstore_2)s) -> %(hstore_3)s AS anon_1", False, - ) - - def test_cols_hstore_pair_array(self): - self._test_cols( - hstore( + ), + ( + lambda self: hstore( postgresql.array(["1", "2"]), postgresql.array(["3", None]) )["1"], ( @@ -2245,72 +2147,68 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): "ARRAY[%(param_3)s, NULL]) -> %(hstore_1)s AS anon_1" ), False, - ) - - def test_cols_hstore_single_array(self): - self._test_cols( - hstore(postgresql.array(["1", "2", "3", None]))["3"], + ), + ( + lambda self: hstore(postgresql.array(["1", "2", "3", None]))["3"], ( "hstore(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, NULL]) " "-> %(hstore_1)s AS anon_1" ), False, - ) - - def test_cols_concat(self): - self._test_cols( - self.hashcol.concat(hstore(cast(self.test_table.c.id, Text), "3")), + ), + ( + lambda self: self.hashcol.concat( + hstore(cast(self.test_table.c.id, Text), "3") + ), ( "test_table.hash || hstore(CAST(test_table.id AS TEXT), " "%(hstore_1)s) AS anon_1" ), True, - ) - - def test_cols_concat_op(self): - self._test_cols( - hstore("foo", "bar") + self.hashcol, + ), + ( + lambda self: hstore("foo", "bar") + self.hashcol, "hstore(%(hstore_1)s, %(hstore_2)s) || test_table.hash AS anon_1", True, - ) - - def test_cols_concat_get(self): - self._test_cols( - (self.hashcol + self.hashcol)["foo"], + ), + ( + lambda self: (self.hashcol + self.hashcol)["foo"], "(test_table.hash || test_table.hash) -> %(param_1)s AS anon_1", - ) - - def test_cols_against_is(self): - self._test_cols( - self.hashcol["foo"] != None, # noqa + True, + ), + ( + lambda self: self.hashcol["foo"] != None, # noqa "(test_table.hash -> %(hash_1)s) IS NOT NULL AS anon_1", - ) - - def test_cols_keys(self): - self._test_cols( + True, + ), + ( # hide from 2to3 - getattr(self.hashcol, "keys")(), + lambda self: getattr(self.hashcol, "keys")(), "akeys(test_table.hash) AS akeys_1", True, - ) - - def test_cols_vals(self): - self._test_cols( - self.hashcol.vals(), "avals(test_table.hash) AS avals_1", True - ) - - def test_cols_array(self): - self._test_cols( - self.hashcol.array(), + ), + ( + lambda self: self.hashcol.vals(), + "avals(test_table.hash) AS avals_1", + True, + ), + ( + lambda self: self.hashcol.array(), "hstore_to_array(test_table.hash) AS hstore_to_array_1", True, - ) - - def test_cols_matrix(self): - self._test_cols( - self.hashcol.matrix(), + ), + ( + lambda self: self.hashcol.matrix(), "hstore_to_matrix(test_table.hash) AS hstore_to_matrix_1", True, + ), + ) + def test_cols(self, colclause_fn, expected, from_): + colclause = colclause_fn(self) + stmt = select([colclause]) + self.assert_compile( + stmt, + ("SELECT %s" + (" FROM test_table" if from_ else "")) % expected, ) @@ -2850,7 +2748,36 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase): ) self.jsoncol = self.test_table.c.test_column - def _test_where(self, whereclause, expected): + @testing.combinations( + ( + lambda self: self.jsoncol["bar"] == None, # noqa + "(test_table.test_column -> %(test_column_1)s) IS NULL", + ), + ( + lambda self: self.jsoncol[("foo", 1)] == None, # noqa + "(test_table.test_column #> %(test_column_1)s) IS NULL", + ), + ( + lambda self: self.jsoncol["bar"].astext == None, # noqa + "(test_table.test_column ->> %(test_column_1)s) IS NULL", + ), + ( + lambda self: self.jsoncol["bar"].astext.cast(Integer) == 5, + "CAST((test_table.test_column ->> %(test_column_1)s) AS INTEGER) " + "= %(param_1)s", + ), + ( + lambda self: self.jsoncol["bar"].cast(Integer) == 5, + "CAST((test_table.test_column -> %(test_column_1)s) AS INTEGER) " + "= %(param_1)s", + ), + ( + lambda self: self.jsoncol[("foo", 1)].astext == None, # noqa + "(test_table.test_column #>> %(test_column_1)s) IS NULL", + ), + ) + def test_where(self, whereclause_fn, expected): + whereclause = whereclause_fn(self) stmt = select([self.test_table]).where(whereclause) self.assert_compile( stmt, @@ -2858,27 +2785,6 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase): "WHERE %s" % expected, ) - def _test_cols(self, colclause, expected, from_=True): - stmt = select([colclause]) - self.assert_compile( - stmt, - ("SELECT %s" + (" FROM test_table" if from_ else "")) % expected, - ) - - # This test is a bit misleading -- in real life you will need to cast to - # do anything - def test_where_getitem(self): - self._test_where( - self.jsoncol["bar"] == None, # noqa - "(test_table.test_column -> %(test_column_1)s) IS NULL", - ) - - def test_where_path(self): - self._test_where( - self.jsoncol[("foo", 1)] == None, # noqa - "(test_table.test_column #> %(test_column_1)s) IS NULL", - ) - def test_path_typing(self): col = column("x", JSON()) is_(col["q"].type._type_affinity, types.JSON) @@ -2898,38 +2804,20 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase): is_(col["q"]["p"].astext.type.__class__, MyType) - def test_where_getitem_as_text(self): - self._test_where( - self.jsoncol["bar"].astext == None, # noqa - "(test_table.test_column ->> %(test_column_1)s) IS NULL", - ) - - def test_where_getitem_astext_cast(self): - self._test_where( - self.jsoncol["bar"].astext.cast(Integer) == 5, - "CAST((test_table.test_column ->> %(test_column_1)s) AS INTEGER) " - "= %(param_1)s", - ) - - def test_where_getitem_json_cast(self): - self._test_where( - self.jsoncol["bar"].cast(Integer) == 5, - "CAST((test_table.test_column -> %(test_column_1)s) AS INTEGER) " - "= %(param_1)s", - ) - - def test_where_path_as_text(self): - self._test_where( - self.jsoncol[("foo", 1)].astext == None, # noqa - "(test_table.test_column #>> %(test_column_1)s) IS NULL", - ) - - def test_cols_get(self): - self._test_cols( - self.jsoncol["foo"], + @testing.combinations( + ( + lambda self: self.jsoncol["foo"], "test_table.test_column -> %(test_column_1)s AS anon_1", True, ) + ) + def test_cols(self, colclause_fn, expected, from_): + colclause = colclause_fn(self) + stmt = select([colclause]) + self.assert_compile( + stmt, + ("SELECT %s" + (" FROM test_table" if from_ else "")) % expected, + ) class JSONRoundTripTest(fixtures.TablesTest): @@ -3292,40 +3180,35 @@ class JSONBTest(JSONTest): ) self.jsoncol = self.test_table.c.test_column - # Note - add fixture data for arrays [] - - def test_where_has_key(self): - self._test_where( + @testing.combinations( + ( # hide from 2to3 - getattr(self.jsoncol, "has_key")("data"), + lambda self: getattr(self.jsoncol, "has_key")("data"), "test_table.test_column ? %(test_column_1)s", - ) - - def test_where_has_all(self): - self._test_where( - self.jsoncol.has_all( + ), + ( + lambda self: self.jsoncol.has_all( {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}} ), "test_table.test_column ?& %(test_column_1)s", - ) - - def test_where_has_any(self): - self._test_where( - self.jsoncol.has_any(postgresql.array(["name", "data"])), + ), + ( + lambda self: self.jsoncol.has_any( + postgresql.array(["name", "data"]) + ), "test_table.test_column ?| ARRAY[%(param_1)s, %(param_2)s]", - ) - - def test_where_contains(self): - self._test_where( - self.jsoncol.contains({"k1": "r1v1"}), + ), + ( + lambda self: self.jsoncol.contains({"k1": "r1v1"}), "test_table.test_column @> %(test_column_1)s", - ) - - def test_where_contained_by(self): - self._test_where( - self.jsoncol.contained_by({"foo": "1", "bar": None}), + ), + ( + lambda self: self.jsoncol.contained_by({"foo": "1", "bar": None}), "test_table.test_column <@ %(test_column_1)s", - ) + ), + ) + def test_where(self, whereclause_fn, expected): + super(JSONBTest, self).test_where(whereclause_fn, expected) class JSONBRoundTripTest(JSONRoundTripTest):