]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
call setinputsizes() for integer types
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 May 2018 16:51:40 +0000 (12:51 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 May 2018 23:29:16 +0000 (19:29 -0400)
Altered the Oracle dialect such that when an :class:`.Integer` type is in
use, the cx_Oracle.NUMERIC type is set up for setinputsizes().  In
SQLAlchemy 1.1 and earlier, cx_Oracle.NUMERIC was passed for all numeric
types unconditionally, and in 1.2 this was removed to allow for better
numeric precision.  However, for integers, some database/client setups
will fail to coerce boolean values True/False into integers which introduces
regressive behavior when using SQLAlchemy 1.2.  Overall, the setinputsizes
logic seems like it will need a lot more flexibility going forward so this
is a start for that.

Change-Id: Ida80cc2c2c37ffc0e05da4b5df2dadfab55a01f2
Fixes: #4259
doc/build/changelog/unreleased_12/4259.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/engine/default.py
test/dialect/oracle/test_types.py

diff --git a/doc/build/changelog/unreleased_12/4259.rst b/doc/build/changelog/unreleased_12/4259.rst
new file mode 100644 (file)
index 0000000..ee3abd3
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: bug, oracle
+    :tickets: 4259
+    :versions: 1.3.0b1
+
+    Altered the Oracle dialect such that when an :class:`.Integer` type is in
+    use, the cx_Oracle.NUMERIC type is set up for setinputsizes().  In
+    SQLAlchemy 1.1 and earlier, cx_Oracle.NUMERIC was passed for all numeric
+    types unconditionally, and in 1.2 this was removed to allow for better
+    numeric precision.  However, for integers, some database/client setups
+    will fail to coerce boolean values True/False into integers which introduces
+    regressive behavior when using SQLAlchemy 1.2.  Overall, the setinputsizes
+    logic seems like it will need a lot more flexibility going forward so this
+    is a start for that.
index 1605c3a67b95a0cb293260ff2548e0c9b21d3a6e..0bd682d19eeff75b1b82c9a5f7a18a391799ded2 100644 (file)
@@ -240,6 +240,7 @@ class _OracleInteger(sqltypes.Integer):
         return handler
 
 
+
 class _OracleNumeric(sqltypes.Numeric):
     is_number = False
 
@@ -326,8 +327,6 @@ class _OracleNUMBER(_OracleNumeric):
     is_number = True
 
 
-
-
 class _OracleDate(sqltypes.Date):
     def bind_processor(self, dialect):
         return None
@@ -595,7 +594,7 @@ class OracleDialect_cx_oracle(OracleDialect):
 
     driver = "cx_oracle"
 
-    colspecs = colspecs = {
+    colspecs = {
         sqltypes.Numeric: _OracleNumeric,
         sqltypes.Float: _OracleNumeric,
         sqltypes.Integer: _OracleInteger,
@@ -654,7 +653,8 @@ class OracleDialect_cx_oracle(OracleDialect):
             self._include_setinputsizes = {
                 cx_Oracle.NCLOB, cx_Oracle.CLOB, cx_Oracle.LOB,
                 cx_Oracle.NCHAR, cx_Oracle.FIXED_NCHAR,
-                cx_Oracle.BLOB, cx_Oracle.FIXED_CHAR, cx_Oracle.TIMESTAMP
+                cx_Oracle.BLOB, cx_Oracle.FIXED_CHAR, cx_Oracle.TIMESTAMP,
+                _OracleInteger
             }
 
         self._is_cx_oracle_6 = self.cx_oracle_ver >= (6, )
index 099e694b6841738434c73483e8e9ee29f1818152..ea806deaeee873dd976b9e440ecfe0ac91d71b56 100644 (file)
@@ -1126,19 +1126,26 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         if not hasattr(self.compiled, 'bind_names'):
             return
 
-        types = dict(
-            (self.compiled.bind_names[bindparam], bindparam.type)
-            for bindparam in self.compiled.bind_names)
+        key_to_dbapi_type = {}
+        for bindparam in self.compiled.bind_names:
+            key = self.compiled.bind_names[bindparam]
+            dialect_impl = bindparam.type.dialect_impl(self.dialect)
+            dialect_impl_cls = type(dialect_impl)
+            dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi)
+            if dbtype is not None and (
+                not exclude_types or dbtype not in exclude_types and
+                dialect_impl_cls not in exclude_types
+            ) and (
+                not include_types or dbtype in include_types or
+                dialect_impl_cls in include_types
+            ):
+                key_to_dbapi_type[key] = dbtype
 
         if self.dialect.positional:
             inputsizes = []
             for key in self.compiled.positiontup:
-                typeengine = types[key]
-                dbtype = typeengine.dialect_impl(self.dialect).\
-                    get_dbapi_type(self.dialect.dbapi)
-                if dbtype is not None and \
-                        (not exclude_types or dbtype not in exclude_types) and \
-                        (not include_types or dbtype in include_types):
+                if key in key_to_dbapi_type:
+                    dbtype = key_to_dbapi_type[key]
                     if key in self._expanded_parameters:
                         inputsizes.extend(
                             [dbtype] * len(self._expanded_parameters[key]))
@@ -1152,12 +1159,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         else:
             inputsizes = {}
             for key in self.compiled.bind_names.values():
-                typeengine = types[key]
-                dbtype = typeengine.dialect_impl(self.dialect).\
-                    get_dbapi_type(self.dialect.dbapi)
-                if dbtype is not None and \
-                        (not exclude_types or dbtype not in exclude_types) and \
-                        (not include_types or dbtype in include_types):
+                if key in key_to_dbapi_type:
+                    dbtype = key_to_dbapi_type[key]
                     if translate:
                         # TODO: this part won't work w/ the
                         # expanded_parameters feature, e.g. for cx_oracle
index 394fc29a85ff39a489ccad0d9941b7bbdf8daedc..a13a578b608bf26747cd072c25b4107a388cffcc 100644 (file)
@@ -10,7 +10,7 @@ from sqlalchemy.testing import (fixtures,
 from sqlalchemy import testing
 from sqlalchemy import Integer, Text, LargeBinary, Unicode, UniqueConstraint,\
     Index, MetaData, select, inspect, ForeignKey, String, func, \
-    TypeDecorator, bindparam, Numeric, TIMESTAMP, CHAR, text, \
+    TypeDecorator, bindparam, Numeric, TIMESTAMP, CHAR, text, SmallInteger, \
     literal_column, VARCHAR, create_engine, Date, NVARCHAR, \
     ForeignKeyConstraint, Sequence, Float, DateTime, cast, UnicodeText, \
     union, except_, type_coerce, or_, outerjoin, DATE, NCHAR, outparam, \
@@ -28,6 +28,7 @@ import datetime
 import os
 from sqlalchemy import sql
 from sqlalchemy.testing.mock import Mock
+from sqlalchemy.testing import mock
 
 
 class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -824,3 +825,83 @@ class EuroNumericTest(fixtures.TestBase):
                 assert type(test_exp) is type(exp)
 
 
+class SetInputSizesTest(fixtures.TestBase):
+    __only_on__ = 'oracle+cx_oracle'
+    __backend__ = True
+
+    @testing.provide_metadata
+    def _test_setinputsizes(self, datatype, value, sis_value):
+        m = self.metadata
+        t1 = Table('t1', m, Column('foo', datatype))
+        t1.create()
+
+        class CursorWrapper(object):
+            # cx_oracle cursor can't be modified so we have to
+            # invent a whole wrapping scheme
+
+            def __init__(self, connection_fairy):
+                self.cursor = connection_fairy.connection.cursor()
+                self.mock = mock.Mock()
+                connection_fairy.info['mock'] = self.mock
+
+            def setinputsizes(self, *arg, **kw):
+                self.mock.setinputsizes(*arg, **kw)
+                self.cursor.setinputsizes(*arg, **kw)
+
+            def __getattr__(self, key):
+                return getattr(self.cursor, key)
+
+        with testing.db.connect() as conn:
+            connection_fairy = conn.connection
+            with mock.patch.object(
+                    connection_fairy, "cursor",
+                    lambda: CursorWrapper(connection_fairy)
+            ):
+                conn.execute(
+                    t1.insert(), {"foo": value}
+                )
+
+            if sis_value:
+                eq_(
+                    conn.info['mock'].mock_calls,
+                    [mock.call.setinputsizes(foo=sis_value)]
+                )
+            else:
+                eq_(
+                    conn.info['mock'].mock_calls,
+                    [mock.call.setinputsizes()]
+                )
+
+    def test_smallint_setinputsizes(self):
+        self._test_setinputsizes(
+            SmallInteger, 25, testing.db.dialect.dbapi.NUMBER)
+
+    def test_int_setinputsizes(self):
+        self._test_setinputsizes(
+            Integer, 25, testing.db.dialect.dbapi.NUMBER)
+
+    def test_numeric_setinputsizes(self):
+        self._test_setinputsizes(
+            Numeric(10, 8), decimal.Decimal("25.34534"), None)
+
+    def test_float_setinputsizes(self):
+        self._test_setinputsizes(Float(15), 25.34534, None)
+
+    def test_unicode(self):
+        self._test_setinputsizes(
+            Unicode(30), u("test"), testing.db.dialect.dbapi.NCHAR)
+
+    def test_string(self):
+        self._test_setinputsizes(String(30), "test", None)
+
+    def test_char(self):
+        self._test_setinputsizes(
+            CHAR(30), "test", testing.db.dialect.dbapi.FIXED_CHAR)
+
+    def test_nchar(self):
+        self._test_setinputsizes(
+            NCHAR(30), u("test"), testing.db.dialect.dbapi.NCHAR)
+
+    def test_long(self):
+        self._test_setinputsizes(
+            oracle.LONG(), "test", None)