From: Mike Bayer
Date: Thu, 15 Oct 2020 22:18:03 +0000 (-0400)
Subject: Genericize setinputsizes and support pyodbc
X-Git-Tag: rel_1_4_0b1~25
X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=87c24c498cb660e7a8d7d4dd5f630b967f79d3c8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

Genericize setinputsizes and support pyodbc

Reworked the "setinputsizes()" set of dialect hooks to be correctly
extensible for any arbitrary DBAPI, by allowing dialects individual hooks
that may invoke cursor.setinputsizes() in the appropriate style for that
DBAPI. In particular this is intended to support pyodbc's style of usage
which is fundamentally different from that of cx_Oracle. Added support
for pyodbc.

Fixes: #5649
Change-Id: I9f1794f8368bf3663a286932cfe3992dae244a10
---

diff --git a/doc/build/changelog/unreleased_14/5649.rst b/doc/build/changelog/unreleased_14/5649.rst
new file mode 100644
index 0000000000..20e69c4c3d
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/5649.rst
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, engine, pyodbc
+    :tickets: 5649
+
+    Reworked the "setinputsizes()" set of dialect hooks to be correctly
+    extensible for any arbitrary DBAPI, by allowing dialects individual hooks
+    that may invoke cursor.setinputsizes() in the appropriate style for that
+    DBAPI. In particular this is intended to support pyodbc's style of usage
+    which is fundamentally different from that of cx_Oracle. Added support
+    for pyodbc.
+
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
index e1a7c99f4d..7801613048 100644
--- a/lib/sqlalchemy/connectors/pyodbc.py
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -24,6 +24,8 @@ class PyODBCConnector(Connector):
     supports_native_decimal = True
     default_paramstyle = "named"
 
+    use_setinputsizes = True
+
     # for non-DSN connections, this *may* be used to
     # hold the desired driver name
     pyodbc_driver_name = None
@@ -155,6 +157,21 @@ class PyODBCConnector(Connector):
                     version.append(n)
         return tuple(version)
 
+    def do_set_input_sizes(self, cursor, list_of_tuples, context):
+        # the rules for these types seem a little strange, as you can pass
+        # non-tuples as well as tuples, however it seems to assume "0"
+        # for the subsequent values if you don't pass a tuple, which fails
+        # for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype
+        # that ticket #5649 is targeting.
+        cursor.setinputsizes(
+            [
+                (dbtype, None, None)
+                if not isinstance(dbtype, tuple)
+                else dbtype
+                for key, dbtype, sqltype in list_of_tuples
+            ]
+        )
+
     def set_isolation_level(self, connection, level):
         # adjust for ConnectionFairy being present
         # allows attribute set e.g. "connection.autocommit = True"
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index d1b69100f7..7bde19090d 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -687,15 +687,12 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
             if self.compiled._quoted_bind_names:
                 self._setup_quoted_bind_names()
 
-            self.set_input_sizes(
-                self.compiled._quoted_bind_names,
-                include_types=self.dialect._include_setinputsizes,
-            )
-
             self._generate_out_parameter_vars()
 
             self._generate_cursor_outputtype_handler()
 
+        self.include_set_input_sizes = self.dialect._include_setinputsizes
+
     def post_exec(self):
         if self.compiled and self.out_parameters and self.compiled.returning:
             # create a fake cursor result from the out parameters. unlike
@@ -746,6 +743,8 @@ class OracleDialect_cx_oracle(OracleDialect):
     supports_unicode_statements = True
     supports_unicode_binds = True
 
+    use_setinputsizes = True
+
     driver = "cx_oracle"
 
     colspecs = {
@@ -1172,6 +1171,35 @@ class OracleDialect_cx_oracle(OracleDialect):
             if oci_prepared:
                 self.do_commit(connection.connection)
 
+    def do_set_input_sizes(self, cursor, list_of_tuples, context):
+        if self.positional:
+            # not usually used, here to support if someone is modifying
+            # the dialect to use positional style
+            cursor.setinputsizes(
+                *[dbtype for key, dbtype, sqltype in list_of_tuples]
+            )
+        else:
+            collection = (
+                (key, dbtype)
+                for key, dbtype, sqltype in list_of_tuples
+                if dbtype
+            )
+            if context and context.compiled:
+                quoted_bind_names = context.compiled._quoted_bind_names
+                collection = (
+                    (quoted_bind_names.get(key, key), dbtype)
+                    for key, dbtype in collection
+                )
+
+            if not self.supports_unicode_binds:
+                # oracle 8 only
+                collection = (
+                    (self._encoder(key)[0], dbtype)
+                    for key, dbtype in collection
+                )
+
+            cursor.setinputsizes(**{key: dbtype for key, dbtype in collection})
+
     def do_recover_twophase(self, connection):
         connection.info.pop("cx_oracle_prepared", None)
 
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index a4937d0d29..7d679731b4 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -245,7 +245,7 @@ class PGExecutionContext_asyncpg(PGExecutionContext):
         # we have to exclude ENUM because "enum" not really a "type"
         # we can cast to, it has to be the name of the type itself.
         # for now we just omit it from casting
-        self.set_input_sizes(exclude_types={AsyncAdapt_asyncpg_dbapi.ENUM})
+        self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}
 
     def create_server_side_cursor(self):
         return self._dbapi_connection.cursor(server_side=True)
@@ -687,6 +687,8 @@ class PGDialect_asyncpg(PGDialect):
     statement_compiler = PGCompiler_asyncpg
     preparer = PGIdentifierPreparer_asyncpg
 
+    use_setinputsizes = True
+
     use_native_uuid = True
 
     colspecs = util.update_copy(
@@ -787,6 +789,20 @@ class PGDialect_asyncpg(PGDialect):
             e, self.dbapi.InterfaceError
         ) and "connection is closed" in str(e)
 
+    def do_set_input_sizes(self, cursor, list_of_tuples, context):
+        if self.positional:
+            cursor.setinputsizes(
+                *[dbtype for key, dbtype, sqltype in list_of_tuples]
+            )
+        else:
+            cursor.setinputsizes(
+                **{
+                    key: dbtype
+                    for key, dbtype, sqltype in list_of_tuples
+                    if dbtype
+                }
+            )
+
     def on_connect(self):
         super_connect = super(PGDialect_asyncpg, self).on_connect()
 
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index b2faa4243b..4392491576 100644
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -234,8 +234,6 @@ class PGExecutionContext_pg8000(PGExecutionContext):
         if not self.compiled:
             return
 
-        self.set_input_sizes()
-
 
 class PGCompiler_pg8000(PGCompiler):
     def visit_mod_binary(self, binary, operator, **kw):
@@ -265,6 +263,8 @@ class PGDialect_pg8000(PGDialect):
     statement_compiler = PGCompiler_pg8000
     preparer = PGIdentifierPreparer_pg8000
 
+    use_setinputsizes = True
+
     # reversed as of pg8000 1.16.6.  1.16.5 and lower
     # are no longer compatible
     description_encoding = None
@@ -407,6 +407,20 @@ class PGDialect_pg8000(PGDialect):
         cursor.execute("COMMIT")
         cursor.close()
 
+    def do_set_input_sizes(self, cursor, list_of_tuples, context):
+        if self.positional:
+            cursor.setinputsizes(
+                *[dbtype for key, dbtype, sqltype in list_of_tuples]
+            )
+        else:
+            cursor.setinputsizes(
+                **{
+                    key: dbtype
+                    for key, dbtype, sqltype in list_of_tuples
+                    if dbtype
+                }
+            )
+
     def do_begin_twophase(self, connection, xid):
         connection.connection.tpc_begin((0, xid, ""))
 
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 4fbdec1452..9a5518a961 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1553,6 +1553,9 @@ class Connection(Connectable):
 
         context.pre_exec()
 
+        if dialect.use_setinputsizes:
+            context._set_input_sizes()
+
         cursor, statement, parameters = (
             context.cursor,
             context.statement,
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index ff29c3b9dd..d63cb4addd 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -56,6 +56,7 @@ class DefaultDialect(interfaces.Dialect):
     supports_alter = True
     supports_comments = False
     inline_comments = False
+    use_setinputsizes = False
 
     # the first value we'd get for an autoincrement
     # column.
@@ -782,6 +783,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
     returned_default_rows = None
     execution_options = util.immutabledict()
 
+    include_set_input_sizes = None
+    exclude_set_input_sizes = None
+
     cursor_fetch_strategy = _cursor._DEFAULT_FETCH
 
     cache_stats = None
@@ -1477,9 +1481,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             self.compiled.postfetch
         )
 
-    def set_input_sizes(
-        self, translate=None, include_types=None, exclude_types=None
-    ):
+    def _set_input_sizes(self):
         """Given a cursor and ClauseParameters, call the appropriate
         style of ``setinputsizes()`` on the cursor, using DB-API types
         from the bind parameter's ``TypeEngine`` objects.
@@ -1488,14 +1490,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         currently cx_oracle.
""" - if self.isddl: - return None + if self.isddl or self.is_text: + return inputsizes = self.compiled._get_set_input_sizes_lookup( - translate=translate, - include_types=include_types, - exclude_types=exclude_types, + include_types=self.include_set_input_sizes, + exclude_types=self.exclude_set_input_sizes, ) + if inputsizes is None: return @@ -1506,82 +1508,52 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ) if self.dialect.positional: - positional_inputsizes = [] - for key in self.compiled.positiontup: - bindparam = self.compiled.binds[key] - if bindparam in self.compiled.literal_execute_params: - continue - - if key in self._expanded_parameters: - if bindparam.type._is_tuple_type: - num = len(bindparam.type.types) - dbtypes = inputsizes[bindparam] - positional_inputsizes.extend( - [ - dbtypes[idx % num] - for idx, key in enumerate( - self._expanded_parameters[key] - ) - ] - ) - else: - dbtype = inputsizes.get(bindparam, None) - positional_inputsizes.extend( - dbtype for dbtype in self._expanded_parameters[key] - ) - else: - dbtype = inputsizes[bindparam] - positional_inputsizes.append(dbtype) - try: - self.cursor.setinputsizes(*positional_inputsizes) - except BaseException as e: - self.root_connection._handle_dbapi_exception( - e, None, None, None, self - ) + items = [ + (key, self.compiled.binds[key]) + for key in self.compiled.positiontup + ] else: - keyword_inputsizes = {} - for bindparam, key in self.compiled.bind_names.items(): - if bindparam in self.compiled.literal_execute_params: - continue - - if key in self._expanded_parameters: - if bindparam.type._is_tuple_type: - num = len(bindparam.type.types) - dbtypes = inputsizes[bindparam] - keyword_inputsizes.update( - [ - (key, dbtypes[idx % num]) - for idx, key in enumerate( - self._expanded_parameters[key] - ) - ] + items = [ + (key, bindparam) + for bindparam, key in self.compiled.bind_names.items() + ] + + generic_inputsizes = [] + for key, bindparam in items: + if bindparam in self.compiled.literal_execute_params: + continue + + if key in self._expanded_parameters: + if bindparam.type._is_tuple_type: + num = len(bindparam.type.types) + dbtypes = inputsizes[bindparam] + generic_inputsizes.extend( + ( + paramname, + dbtypes[idx % num], + bindparam.type.types[idx % num], ) - else: - dbtype = inputsizes.get(bindparam, None) - if dbtype is not None: - keyword_inputsizes.update( - (expand_key, dbtype) - for expand_key in self._expanded_parameters[ - key - ] - ) + for idx, paramname in enumerate( + self._expanded_parameters[key] + ) + ) else: dbtype = inputsizes.get(bindparam, None) - if dbtype is not None: - if translate: - # TODO: this part won't work w/ the - # expanded_parameters feature, e.g. 
for cx_oracle - # quoted bound names - key = translate.get(key, key) - if not self.dialect.supports_unicode_binds: - key = self.dialect._encoder(key)[0] - keyword_inputsizes[key] = dbtype - try: - self.cursor.setinputsizes(**keyword_inputsizes) - except BaseException as e: - self.root_connection._handle_dbapi_exception( - e, None, None, None, self - ) + generic_inputsizes.extend( + (paramname, dbtype, bindparam.type) + for paramname in self._expanded_parameters[key] + ) + else: + dbtype = inputsizes.get(bindparam, None) + generic_inputsizes.append((key, dbtype, bindparam.type)) + try: + self.dialect.do_set_input_sizes( + self.cursor, generic_inputsizes, self + ) + except BaseException as e: + self.root_connection._handle_dbapi_exception( + e, None, None, None, self + ) def _exec_default(self, column, default, type_): if default.is_sequence: diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index 9f30a83ce8..ccc6c59685 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -792,10 +792,9 @@ class DialectEvents(event.Events): or a dictionary of string parameter keys to DBAPI type objects for a named bound parameter execution style. - Most dialects **do not use** this method at all; the only built-in - dialect which uses this hook is the cx_Oracle dialect. The hook here - is made available so as to allow customization of how datatypes are set - up with the cx_Oracle DBAPI. + The setinputsizes hook overall is only used for dialects which include + the flag ``use_setinputsizes=True``. Dialects which use this + include cx_Oracle, pg8000, asyncpg, and pyodbc dialects. .. versionadded:: 1.2.9 diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index b7bd3627bd..a7f71f5e5f 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -569,6 +569,21 @@ class Dialect(object): raise NotImplementedError() + def do_set_input_sizes(self, cursor, list_of_tuples, context): + """invoke the cursor.setinputsizes() method with appropriate arguments + + This hook is called if the dialect.use_inputsizes flag is set to True. + Parameter data is passed in a list of tuples (paramname, dbtype, + sqltype), where ``paramname`` is the key of the parameter in the + statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the + SQLAlchemy type. The order of tuples is in the correct parameter order. + + .. versionadded:: 1.4 + + + """ + raise NotImplementedError() + def create_xid(self): """Create a two-phase transaction ID. 
diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py
index 58680f3564..d1009eca9e 100644
--- a/lib/sqlalchemy/event/registry.py
+++ b/lib/sqlalchemy/event/registry.py
@@ -229,6 +229,7 @@ class _EventKey(object):
                 "No listeners found for event %s / %r / %s "
                 % (self.target, self.identifier, self.fn)
             )
+
         dispatch_reg = _key_to_collection.pop(key)
 
         for collection_ref, listener_ref in dispatch_reg.items():
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 2fa9961eba..23cd778d04 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -971,7 +971,7 @@ class SQLCompiler(Compiled):
 
     @util.memoized_instancemethod
     def _get_set_input_sizes_lookup(
-        self, translate=None, include_types=None, exclude_types=None
+        self, include_types=None, exclude_types=None
     ):
         if not hasattr(self, "bind_names"):
             return None
@@ -986,7 +986,7 @@ class SQLCompiler(Compiled):
         # for a dialect impl, also subclass Emulated first which overrides
         # this behavior in those cases to behave like the default.
 
-        if not include_types and not exclude_types:
+        if include_types is None and exclude_types is None:
 
             def _lookup_type(typ):
                 dialect_impl = typ._unwrapped_dialect_impl(dialect)
@@ -1001,12 +1001,12 @@ class SQLCompiler(Compiled):
                 if (
                     dbtype is not None
                     and (
-                        not exclude_types
+                        exclude_types is None
                         or dbtype not in exclude_types
                         and type(dialect_impl) not in exclude_types
                     )
                     and (
-                        not include_types
+                        include_types is None
                         or dbtype in include_types
                         or type(dialect_impl) in include_types
                     )
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index a1d6d2725a..2ca6bdd7c0 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -3300,3 +3300,224 @@ class FutureExecuteTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
             "'branching' of new connections.",
             connection.connect,
         )
+
+
+class SetInputSizesTest(fixtures.TablesTest):
+    __backend__ = True
+
+    __requires__ = ("independent_connections",)
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "users",
+            metadata,
+            Column("user_id", INT, primary_key=True, autoincrement=False),
+            Column("user_name", VARCHAR(20)),
+        )
+
+    @testing.fixture
+    def input_sizes_fixture(self):
+        canary = mock.Mock()
+
+        def do_set_input_sizes(cursor, list_of_tuples, context):
+            if not engine.dialect.positional:
+                # sort by "user_id", "user_name", or otherwise
+                # param name for a non-positional dialect, so that we can
+                # confirm the ordering.  mostly a py2 thing probably can't
+                # occur on py3.6+ since we are passing dictionaries with
+                # "user_id", "user_name"
+                list_of_tuples = sorted(
+                    list_of_tuples, key=lambda elem: elem[0]
+                )
+            canary.do_set_input_sizes(cursor, list_of_tuples, context)
+
+        def pre_exec(self):
+            self.translate_set_input_sizes = None
+            self.include_set_input_sizes = None
+            self.exclude_set_input_sizes = None
+
+        engine = testing_engine()
+        engine.connect().close()
+
+        # the idea of this test is we fully replace the dialect
+        # do_set_input_sizes with a mock, and we can then intercept
+        # the setting passed to the dialect.  the test table uses very
+        # "safe" datatypes so that the DBAPI does not actually need
+        # setinputsizes() called in order to work.
+
+        with mock.patch.object(
+            engine.dialect, "use_setinputsizes", True
+        ), mock.patch.object(
+            engine.dialect, "do_set_input_sizes", do_set_input_sizes
+        ), mock.patch.object(
+            engine.dialect.execution_ctx_cls, "pre_exec", pre_exec
+        ):
+            yield engine, canary
+
+    def test_set_input_sizes_no_event(self, input_sizes_fixture):
+        engine, canary = input_sizes_fixture
+
+        with engine.connect() as conn:
+            conn.execute(
+                self.tables.users.insert(),
+                [
+                    {"user_id": 1, "user_name": "n1"},
+                    {"user_id": 2, "user_name": "n2"},
+                ],
+            )
+
+        eq_(
+            canary.mock_calls,
+            [
+                call.do_set_input_sizes(
+                    mock.ANY,
+                    [
+                        (
+                            "user_id",
+                            mock.ANY,
+                            testing.eq_type_affinity(Integer),
+                        ),
+                        (
+                            "user_name",
+                            mock.ANY,
+                            testing.eq_type_affinity(String),
+                        ),
+                    ],
+                    mock.ANY,
+                )
+            ],
+        )
+
+    def test_set_input_sizes_expanding_param(self, input_sizes_fixture):
+        engine, canary = input_sizes_fixture
+
+        with engine.connect() as conn:
+            conn.execute(
+                select(self.tables.users).where(
+                    self.tables.users.c.user_name.in_(["x", "y", "z"])
+                )
+            )
+
+        eq_(
+            canary.mock_calls,
+            [
+                call.do_set_input_sizes(
+                    mock.ANY,
+                    [
+                        (
+                            "user_name_1_1",
+                            mock.ANY,
+                            testing.eq_type_affinity(String),
+                        ),
+                        (
+                            "user_name_1_2",
+                            mock.ANY,
+                            testing.eq_type_affinity(String),
+                        ),
+                        (
+                            "user_name_1_3",
+                            mock.ANY,
+                            testing.eq_type_affinity(String),
+                        ),
+                    ],
+                    mock.ANY,
+                )
+            ],
+        )
+
+    @testing.requires.tuple_in
+    def test_set_input_sizes_expanding_tuple_param(self, input_sizes_fixture):
+        engine, canary = input_sizes_fixture
+
+        from sqlalchemy import tuple_
+
+        with engine.connect() as conn:
+            conn.execute(
+                select(self.tables.users).where(
+                    tuple_(
+                        self.tables.users.c.user_id,
+                        self.tables.users.c.user_name,
+                    ).in_([(1, "x"), (2, "y")])
+                )
+            )
+
+        eq_(
+            canary.mock_calls,
+            [
+                call.do_set_input_sizes(
+                    mock.ANY,
+                    [
+                        (
+                            "param_1_1_1",
+                            mock.ANY,
+                            testing.eq_type_affinity(Integer),
+                        ),
+                        (
+                            "param_1_1_2",
+                            mock.ANY,
+                            testing.eq_type_affinity(String),
+                        ),
+                        (
+                            "param_1_2_1",
+                            mock.ANY,
+                            testing.eq_type_affinity(Integer),
+                        ),
+                        (
+                            "param_1_2_2",
+                            mock.ANY,
+                            testing.eq_type_affinity(String),
+                        ),
+                    ],
+                    mock.ANY,
+                )
+            ],
+        )
+
+    def test_set_input_sizes_event(self, input_sizes_fixture):
+        engine, canary = input_sizes_fixture
+
+        SPECIAL_STRING = mock.Mock()
+
+        @event.listens_for(engine, "do_setinputsizes")
+        def do_setinputsizes(
+            inputsizes, cursor, statement, parameters, context
+        ):
+            for k in inputsizes:
+                if k.type._type_affinity is String:
+                    inputsizes[k] = (
+                        SPECIAL_STRING,
+                        None,
+                        0,
+                    )
+
+        with engine.connect() as conn:
+            conn.execute(
+                self.tables.users.insert(),
+                [
+                    {"user_id": 1, "user_name": "n1"},
+                    {"user_id": 2, "user_name": "n2"},
+                ],
+            )
+
+        eq_(
+            canary.mock_calls,
+            [
+                call.do_set_input_sizes(
+                    mock.ANY,
+                    [
+                        (
+                            "user_id",
+                            mock.ANY,
+                            testing.eq_type_affinity(Integer),
+                        ),
+                        (
+                            "user_name",
+                            (SPECIAL_STRING, None, 0),
+                            testing.eq_type_affinity(String),
+                        ),
+                    ],
+                    mock.ANY,
+                )
+            ],
+        )
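
For illustration (not part of the patch above), a third-party dialect would
opt into the new hook roughly as sketched below. "MyDialect" is a
hypothetical name, while the ``use_setinputsizes`` flag, the
``do_set_input_sizes()`` signature and the (paramname, dbtype, sqltype)
tuple format are those added by this change; the body mirrors the pg8000 /
asyncpg implementations included in the diff:

    from sqlalchemy.engine import default


    class MyDialect(default.DefaultDialect):
        # ask the Connection to call context._set_input_sizes() before
        # execution, which in turn calls dialect.do_set_input_sizes()
        use_setinputsizes = True

        def do_set_input_sizes(self, cursor, list_of_tuples, context):
            if self.positional:
                # positional paramstyles receive the DBAPI types in
                # statement order
                cursor.setinputsizes(
                    *[dbtype for key, dbtype, sqltype in list_of_tuples]
                )
            else:
                # named paramstyles receive a mapping of parameter name to
                # DBAPI type, omitting parameters with no DBAPI type resolved
                cursor.setinputsizes(
                    **{
                        key: dbtype
                        for key, dbtype, sqltype in list_of_tuples
                        if dbtype
                    }
                )

The types passed through can also be customized per statement with the
existing ``do_setinputsizes`` dialect event, as exercised by
``test_set_input_sizes_event`` in the new test suite.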