if not hasattr(self.compiled, 'bind_names'):
return
- types = dict(
- (self.compiled.bind_names[bindparam], bindparam.type)
- for bindparam in self.compiled.bind_names)
+ key_to_dbapi_type = {}
+ for bindparam in self.compiled.bind_names:
+ key = self.compiled.bind_names[bindparam]
+ dialect_impl = bindparam.type.dialect_impl(self.dialect)
+ dialect_impl_cls = type(dialect_impl)
+ dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi)
+ if dbtype is not None and (
+ not exclude_types or dbtype not in exclude_types and
+ dialect_impl_cls not in exclude_types
+ ) and (
+ not include_types or dbtype in include_types or
+ dialect_impl_cls in include_types
+ ):
+ key_to_dbapi_type[key] = dbtype
if self.dialect.positional:
inputsizes = []
for key in self.compiled.positiontup:
- typeengine = types[key]
- dbtype = typeengine.dialect_impl(self.dialect).\
- get_dbapi_type(self.dialect.dbapi)
- if dbtype is not None and \
- (not exclude_types or dbtype not in exclude_types) and \
- (not include_types or dbtype in include_types):
+ if key in key_to_dbapi_type:
+ dbtype = key_to_dbapi_type[key]
if key in self._expanded_parameters:
inputsizes.extend(
[dbtype] * len(self._expanded_parameters[key]))
else:
inputsizes = {}
for key in self.compiled.bind_names.values():
- typeengine = types[key]
- dbtype = typeengine.dialect_impl(self.dialect).\
- get_dbapi_type(self.dialect.dbapi)
- if dbtype is not None and \
- (not exclude_types or dbtype not in exclude_types) and \
- (not include_types or dbtype in include_types):
+ if key in key_to_dbapi_type:
+ dbtype = key_to_dbapi_type[key]
if translate:
# TODO: this part won't work w/ the
# expanded_parameters feature, e.g. for cx_oracle
from sqlalchemy import testing
from sqlalchemy import Integer, Text, LargeBinary, Unicode, UniqueConstraint,\
Index, MetaData, select, inspect, ForeignKey, String, func, \
- TypeDecorator, bindparam, Numeric, TIMESTAMP, CHAR, text, \
+ TypeDecorator, bindparam, Numeric, TIMESTAMP, CHAR, text, SmallInteger, \
literal_column, VARCHAR, create_engine, Date, NVARCHAR, \
ForeignKeyConstraint, Sequence, Float, DateTime, cast, UnicodeText, \
union, except_, type_coerce, or_, outerjoin, DATE, NCHAR, outparam, \
import os
from sqlalchemy import sql
from sqlalchemy.testing.mock import Mock
+from sqlalchemy.testing import mock
class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL):
assert type(test_exp) is type(exp)
+class SetInputSizesTest(fixtures.TestBase):
+ __only_on__ = 'oracle+cx_oracle'
+ __backend__ = True
+
+ @testing.provide_metadata
+ def _test_setinputsizes(self, datatype, value, sis_value):
+ m = self.metadata
+ t1 = Table('t1', m, Column('foo', datatype))
+ t1.create()
+
+ class CursorWrapper(object):
+ # cx_oracle cursor can't be modified so we have to
+ # invent a whole wrapping scheme
+
+ def __init__(self, connection_fairy):
+ self.cursor = connection_fairy.connection.cursor()
+ self.mock = mock.Mock()
+ connection_fairy.info['mock'] = self.mock
+
+ def setinputsizes(self, *arg, **kw):
+ self.mock.setinputsizes(*arg, **kw)
+ self.cursor.setinputsizes(*arg, **kw)
+
+ def __getattr__(self, key):
+ return getattr(self.cursor, key)
+
+ with testing.db.connect() as conn:
+ connection_fairy = conn.connection
+ with mock.patch.object(
+ connection_fairy, "cursor",
+ lambda: CursorWrapper(connection_fairy)
+ ):
+ conn.execute(
+ t1.insert(), {"foo": value}
+ )
+
+ if sis_value:
+ eq_(
+ conn.info['mock'].mock_calls,
+ [mock.call.setinputsizes(foo=sis_value)]
+ )
+ else:
+ eq_(
+ conn.info['mock'].mock_calls,
+ [mock.call.setinputsizes()]
+ )
+
+ def test_smallint_setinputsizes(self):
+ self._test_setinputsizes(
+ SmallInteger, 25, testing.db.dialect.dbapi.NUMBER)
+
+ def test_int_setinputsizes(self):
+ self._test_setinputsizes(
+ Integer, 25, testing.db.dialect.dbapi.NUMBER)
+
+ def test_numeric_setinputsizes(self):
+ self._test_setinputsizes(
+ Numeric(10, 8), decimal.Decimal("25.34534"), None)
+
+ def test_float_setinputsizes(self):
+ self._test_setinputsizes(Float(15), 25.34534, None)
+
+ def test_unicode(self):
+ self._test_setinputsizes(
+ Unicode(30), u("test"), testing.db.dialect.dbapi.NCHAR)
+
+ def test_string(self):
+ self._test_setinputsizes(String(30), "test", None)
+
+ def test_char(self):
+ self._test_setinputsizes(
+ CHAR(30), "test", testing.db.dialect.dbapi.FIXED_CHAR)
+
+ def test_nchar(self):
+ self._test_setinputsizes(
+ NCHAR(30), u("test"), testing.db.dialect.dbapi.NCHAR)
+
+ def test_long(self):
+ self._test_setinputsizes(
+ oracle.LONG(), "test", None)