From: Mike Bayer Date: Fri, 18 May 2018 16:51:40 +0000 (-0400) Subject: call setinputsizes() for integer types X-Git-Tag: rel_1_3_0b1~179 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c7ae04d1c5c4aa6c6099584ae386d6ab9ef7b290;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git call setinputsizes() for integer types Altered the Oracle dialect such that when an :class:`.Integer` type is in use, the cx_Oracle.NUMERIC type is set up for setinputsizes(). In SQLAlchemy 1.1 and earlier, cx_Oracle.NUMERIC was passed for all numeric types unconditionally, and in 1.2 this was removed to allow for better numeric precision. However, for integers, some database/client setups will fail to coerce boolean values True/False into integers which introduces regressive behavior when using SQLAlchemy 1.2. Overall, the setinputsizes logic seems like it will need a lot more flexibility going forward so this is a start for that. Change-Id: Ida80cc2c2c37ffc0e05da4b5df2dadfab55a01f2 Fixes: #4259 --- diff --git a/doc/build/changelog/unreleased_12/4259.rst b/doc/build/changelog/unreleased_12/4259.rst new file mode 100644 index 0000000000..ee3abd38f2 --- /dev/null +++ b/doc/build/changelog/unreleased_12/4259.rst @@ -0,0 +1,14 @@ +.. change:: + :tags: bug, oracle + :tickets: 4259 + :versions: 1.3.0b1 + + Altered the Oracle dialect such that when an :class:`.Integer` type is in + use, the cx_Oracle.NUMERIC type is set up for setinputsizes(). In + SQLAlchemy 1.1 and earlier, cx_Oracle.NUMERIC was passed for all numeric + types unconditionally, and in 1.2 this was removed to allow for better + numeric precision. However, for integers, some database/client setups + will fail to coerce boolean values True/False into integers which introduces + regressive behavior when using SQLAlchemy 1.2. Overall, the setinputsizes + logic seems like it will need a lot more flexibility going forward so this + is a start for that. diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 1605c3a67b..0bd682d19e 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -240,6 +240,7 @@ class _OracleInteger(sqltypes.Integer): return handler + class _OracleNumeric(sqltypes.Numeric): is_number = False @@ -326,8 +327,6 @@ class _OracleNUMBER(_OracleNumeric): is_number = True - - class _OracleDate(sqltypes.Date): def bind_processor(self, dialect): return None @@ -595,7 +594,7 @@ class OracleDialect_cx_oracle(OracleDialect): driver = "cx_oracle" - colspecs = colspecs = { + colspecs = { sqltypes.Numeric: _OracleNumeric, sqltypes.Float: _OracleNumeric, sqltypes.Integer: _OracleInteger, @@ -654,7 +653,8 @@ class OracleDialect_cx_oracle(OracleDialect): self._include_setinputsizes = { cx_Oracle.NCLOB, cx_Oracle.CLOB, cx_Oracle.LOB, cx_Oracle.NCHAR, cx_Oracle.FIXED_NCHAR, - cx_Oracle.BLOB, cx_Oracle.FIXED_CHAR, cx_Oracle.TIMESTAMP + cx_Oracle.BLOB, cx_Oracle.FIXED_CHAR, cx_Oracle.TIMESTAMP, + _OracleInteger } self._is_cx_oracle_6 = self.cx_oracle_ver >= (6, ) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 099e694b68..ea806deaee 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1126,19 +1126,26 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if not hasattr(self.compiled, 'bind_names'): return - types = dict( - (self.compiled.bind_names[bindparam], bindparam.type) - for bindparam in self.compiled.bind_names) + key_to_dbapi_type = {} + for bindparam in self.compiled.bind_names: + key = self.compiled.bind_names[bindparam] + dialect_impl = bindparam.type.dialect_impl(self.dialect) + dialect_impl_cls = type(dialect_impl) + dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi) + if dbtype is not None and ( + not exclude_types or dbtype not in exclude_types and + dialect_impl_cls not in exclude_types + ) and ( + not include_types or dbtype in include_types or + dialect_impl_cls in include_types + ): + key_to_dbapi_type[key] = dbtype if self.dialect.positional: inputsizes = [] for key in self.compiled.positiontup: - typeengine = types[key] - dbtype = typeengine.dialect_impl(self.dialect).\ - get_dbapi_type(self.dialect.dbapi) - if dbtype is not None and \ - (not exclude_types or dbtype not in exclude_types) and \ - (not include_types or dbtype in include_types): + if key in key_to_dbapi_type: + dbtype = key_to_dbapi_type[key] if key in self._expanded_parameters: inputsizes.extend( [dbtype] * len(self._expanded_parameters[key])) @@ -1152,12 +1159,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): else: inputsizes = {} for key in self.compiled.bind_names.values(): - typeengine = types[key] - dbtype = typeengine.dialect_impl(self.dialect).\ - get_dbapi_type(self.dialect.dbapi) - if dbtype is not None and \ - (not exclude_types or dbtype not in exclude_types) and \ - (not include_types or dbtype in include_types): + if key in key_to_dbapi_type: + dbtype = key_to_dbapi_type[key] if translate: # TODO: this part won't work w/ the # expanded_parameters feature, e.g. for cx_oracle diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index 394fc29a85..a13a578b60 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -10,7 +10,7 @@ from sqlalchemy.testing import (fixtures, from sqlalchemy import testing from sqlalchemy import Integer, Text, LargeBinary, Unicode, UniqueConstraint,\ Index, MetaData, select, inspect, ForeignKey, String, func, \ - TypeDecorator, bindparam, Numeric, TIMESTAMP, CHAR, text, \ + TypeDecorator, bindparam, Numeric, TIMESTAMP, CHAR, text, SmallInteger, \ literal_column, VARCHAR, create_engine, Date, NVARCHAR, \ ForeignKeyConstraint, Sequence, Float, DateTime, cast, UnicodeText, \ union, except_, type_coerce, or_, outerjoin, DATE, NCHAR, outparam, \ @@ -28,6 +28,7 @@ import datetime import os from sqlalchemy import sql from sqlalchemy.testing.mock import Mock +from sqlalchemy.testing import mock class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL): @@ -824,3 +825,83 @@ class EuroNumericTest(fixtures.TestBase): assert type(test_exp) is type(exp) +class SetInputSizesTest(fixtures.TestBase): + __only_on__ = 'oracle+cx_oracle' + __backend__ = True + + @testing.provide_metadata + def _test_setinputsizes(self, datatype, value, sis_value): + m = self.metadata + t1 = Table('t1', m, Column('foo', datatype)) + t1.create() + + class CursorWrapper(object): + # cx_oracle cursor can't be modified so we have to + # invent a whole wrapping scheme + + def __init__(self, connection_fairy): + self.cursor = connection_fairy.connection.cursor() + self.mock = mock.Mock() + connection_fairy.info['mock'] = self.mock + + def setinputsizes(self, *arg, **kw): + self.mock.setinputsizes(*arg, **kw) + self.cursor.setinputsizes(*arg, **kw) + + def __getattr__(self, key): + return getattr(self.cursor, key) + + with testing.db.connect() as conn: + connection_fairy = conn.connection + with mock.patch.object( + connection_fairy, "cursor", + lambda: CursorWrapper(connection_fairy) + ): + conn.execute( + t1.insert(), {"foo": value} + ) + + if sis_value: + eq_( + conn.info['mock'].mock_calls, + [mock.call.setinputsizes(foo=sis_value)] + ) + else: + eq_( + conn.info['mock'].mock_calls, + [mock.call.setinputsizes()] + ) + + def test_smallint_setinputsizes(self): + self._test_setinputsizes( + SmallInteger, 25, testing.db.dialect.dbapi.NUMBER) + + def test_int_setinputsizes(self): + self._test_setinputsizes( + Integer, 25, testing.db.dialect.dbapi.NUMBER) + + def test_numeric_setinputsizes(self): + self._test_setinputsizes( + Numeric(10, 8), decimal.Decimal("25.34534"), None) + + def test_float_setinputsizes(self): + self._test_setinputsizes(Float(15), 25.34534, None) + + def test_unicode(self): + self._test_setinputsizes( + Unicode(30), u("test"), testing.db.dialect.dbapi.NCHAR) + + def test_string(self): + self._test_setinputsizes(String(30), "test", None) + + def test_char(self): + self._test_setinputsizes( + CHAR(30), "test", testing.db.dialect.dbapi.FIXED_CHAR) + + def test_nchar(self): + self._test_setinputsizes( + NCHAR(30), u("test"), testing.db.dialect.dbapi.NCHAR) + + def test_long(self): + self._test_setinputsizes( + oracle.LONG(), "test", None)