]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
propose emulated setinputsizes embedded in the compiler
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 22 Nov 2021 19:28:26 +0000 (14:28 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Nov 2021 21:52:55 +0000 (16:52 -0500)
Add a new system so that PostgreSQL and other dialects have a
reliable way to add casts to bound parameters in SQL statements,
replacing previous use of setinputsizes() for PG dialects.

rationale:

1. psycopg3 will be using the same SQLAlchemy-side "setinputsizes"
   as asyncpg, so we will be seeing a lot more of this

2. the full rendering that SQLAlchemy's compilation is performing
   is in the engine log as well as error messages.   Without this,
   we introduce three levels of SQL rendering, the compiler, the
   hidden "setinputsizes" in SQLAlchemy, and then whatever the DBAPI
   driver does.  With this new approach, users reporting bugs etc.
   will no longer be confused by two separate layers of
   "hidden rendering"; SQLAlchemy's rendering is again
   fully transparent

3. calling upon a setinputsizes() method for every statement execution
   is expensive.  this way, the work is done behind the caching layer

4. for "fast insertmany()", I also want there to be a fast approach
   towards setinputsizes.  As it was, we were going to be taking
   a SQL INSERT with thousands of bound parameter placeholders and
   running a whole second pass on it to apply typecasts.    this way,
   we will at least be able to build the SQL string once without a huge
   second pass over the whole string

5. psycopg2 can use this same system for its ARRAY casts

6. the general need for PostgreSQL to have lots of type casts
   is now mostly in the base PostgreSQL dialect and works independently
   of a DBAPI being present.   dependence on DBAPI symbols that aren't
   complete / consistent / hashable is removed

I was originally going to try to build this into bind_expression(),
but it was revealed this worked poorly with custom bind_expression()
as well as empty sets.   the current impl also doesn't need to
run a second expression pass over the POSTCOMPILE sections, which
came out better than I originally thought it would.

Change-Id: I363e6d593d059add7bcc6d1f6c3f91dd2e683c0c

25 files changed:
doc/build/changelog/unreleased_20/postgresql_binds.rst [new file with mode: 0644]
doc/build/core/internals.rst
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/pg8000.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/engine/__init__.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/testing/assertsql.py
lib/sqlalchemy/testing/suite/test_types.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_query.py
test/dialect/postgresql/test_types.py
test/engine/test_deprecations.py
test/engine/test_execute.py
test/orm/test_lazy_relations.py
test/requirements.py
test/sql/test_type_expressions.py

diff --git a/doc/build/changelog/unreleased_20/postgresql_binds.rst b/doc/build/changelog/unreleased_20/postgresql_binds.rst
new file mode 100644 (file)
index 0000000..15717e7
--- /dev/null
@@ -0,0 +1,30 @@
+.. change::
+    :tags: change, postgresql
+
+    In support of new PostgreSQL features including the psycopg3 dialect as
+    well as extended "fast insertmany" support, the system by which typing
+    information for bound parameters is passed to the PostgreSQL database has
+    been redesigned to use inline casts emitted by the SQL compiler, and is now
+    applied to all PostgreSQL dialects. This is in contrast to the previous
+    approach which would rely upon the DBAPI in use to render these casts
+    itself, which in cases such as that of pg8000 and the adapted asyncpg
+    driver, would use the pep-249 ``setinputsizes()`` method, or with the
+    psycopg2 driver would rely on the driver itself in most cases, with some
+    special exceptions made for ARRAY.
+
+    The new approach now has all PostgreSQL dialects rendering these casts as
+    needed using PostgreSQL double-colon style within the compiler, and the use
+    of ``setinputsizes()`` is removed for PostgreSQL dialects, as this was not
+    generally part of these DBAPIs in any case (pg8000 being the only
+    exception, which added the method at the request of SQLAlchemy developers).
+
+    Advantages to this approach include per-statement performance, as no second
+    pass over the compiled statement is required at execution time, better
+    support for all DBAPIs, as there is now one consistent system of applying
+    typing information, and improved transparency, as the SQL logging output,
+    as well as the string output of a compiled statement, will show these casts
+    present in the statement directly, whereas previously these casts were not
+    visible in logging output as they would occur after the statement was
+    logged.
+
+
index 074acc798d8b0f8493bd34cc35c20bb203d6a2af..aa760073a99ceccb54f0af9c2a2c454b848d6932 100644 (file)
@@ -7,6 +7,9 @@ Some key internal constructs are listed here.
 
 .. currentmodule:: sqlalchemy
 
+.. autoclass:: sqlalchemy.engine.BindTyping
+    :members:
+
 .. autoclass:: sqlalchemy.engine.Compiled
     :members:
 
index 411985b5dcbfd7e8f0a26f213778cd12064adc4e..fdaa8981af4f4174d6fcbba69ed045dbc7bb601a 100644 (file)
@@ -9,6 +9,7 @@ import re
 
 from . import Connector
 from .. import util
+from ..engine import interfaces
 
 
 class PyODBCConnector(Connector):
@@ -21,15 +22,14 @@ class PyODBCConnector(Connector):
     supports_native_decimal = True
     default_paramstyle = "named"
 
-    use_setinputsizes = False
-
     # for non-DSN connections, this *may* be used to
     # hold the desired driver name
     pyodbc_driver_name = None
 
     def __init__(self, use_setinputsizes=False, **kw):
         super(PyODBCConnector, self).__init__(**kw)
-        self.use_setinputsizes = use_setinputsizes
+        if use_setinputsizes:
+            self.bind_typing = interfaces.BindTyping.SETINPUTSIZES
 
     @classmethod
     def dbapi(cls):
@@ -160,8 +160,9 @@ class PyODBCConnector(Connector):
         # for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype
         # that ticket #5649 is targeting.
 
-        # NOTE: as of #6058, this won't be called if the use_setinputsizes flag
-        # is False, or if no types were specified in list_of_tuples
+        # NOTE: as of #6058, this won't be called if the use_setinputsizes
+        # parameter were not passed to the dialect, or if no types were
+        # specified in list_of_tuples
 
         cursor.setinputsizes(
             [
index 2cfcb0e5c39cbc67f6c672e70174227fd79a31c4..672cbd7d962cb7b2887c193ab70f2e7719314f58 100644 (file)
@@ -446,6 +446,7 @@ from ... import processors
 from ... import types as sqltypes
 from ... import util
 from ...engine import cursor as _cursor
+from ...engine import interfaces
 
 
 class _OracleInteger(sqltypes.Integer):
@@ -783,8 +784,6 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
 
         self._generate_cursor_outputtype_handler()
 
-        self.include_set_input_sizes = self.dialect._include_setinputsizes
-
     def post_exec(self):
         if self.compiled and self.out_parameters and self.compiled.returning:
             # create a fake cursor result from the out parameters. unlike
@@ -833,7 +832,7 @@ class OracleDialect_cx_oracle(OracleDialect):
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
 
-    use_setinputsizes = True
+    bind_typing = interfaces.BindTyping.SETINPUTSIZES
 
     driver = "cx_oracle"
 
@@ -909,7 +908,6 @@ class OracleDialect_cx_oracle(OracleDialect):
         cx_Oracle = self.dbapi
 
         if cx_Oracle is None:
-            self._include_setinputsizes = {}
             self.cx_oracle_ver = (0, 0, 0)
         else:
             self.cx_oracle_ver = self._parse_cx_oracle_ver(cx_Oracle.version)
@@ -925,7 +923,7 @@ class OracleDialect_cx_oracle(OracleDialect):
                 )
                 self._cursor_var_unicode_kwargs = util.immutabledict()
 
-            self._include_setinputsizes = {
+            self.include_set_input_sizes = {
                 cx_Oracle.DATETIME,
                 cx_Oracle.NCLOB,
                 cx_Oracle.CLOB,
@@ -935,9 +933,9 @@ class OracleDialect_cx_oracle(OracleDialect):
                 cx_Oracle.BLOB,
                 cx_Oracle.FIXED_CHAR,
                 cx_Oracle.TIMESTAMP,
-                _OracleInteger,
-                _OracleBINARY_FLOAT,
-                _OracleBINARY_DOUBLE,
+                int,  # _OracleInteger,
+                # _OracleBINARY_FLOAT, _OracleBINARY_DOUBLE,
+                cx_Oracle.NATIVE_FLOAT,
             }
 
             self._paramval = lambda value: value.getvalue()
index 0cb574dacf7b5f52f34f652c38ff132378ed011b..ff590c1b07e018fa6f8ce8916526386dbc494c0f 100644 (file)
@@ -330,9 +330,6 @@ class ARRAY(sqltypes.ARRAY):
             and self.item_type.native_enum
         )
 
-    def bind_expression(self, bindvalue):
-        return bindvalue
-
     def bind_processor(self, dialect):
         item_proc = self.item_type.dialect_impl(dialect).bind_processor(
             dialect
index fe1f9fd5ad13a95a84204927406036bf7e4aad09..4ac0971e50a87149c041bf63d2372283cf787e52 100644 (file)
@@ -134,32 +134,28 @@ except ImportError:
     _python_UUID = None
 
 
+class AsyncpgString(sqltypes.String):
+    render_bind_cast = True
+
+
 class AsyncpgTime(sqltypes.Time):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.TIME
+    render_bind_cast = True
 
 
 class AsyncpgDate(sqltypes.Date):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.DATE
+    render_bind_cast = True
 
 
 class AsyncpgDateTime(sqltypes.DateTime):
-    def get_dbapi_type(self, dbapi):
-        if self.timezone:
-            return dbapi.TIMESTAMP_W_TZ
-        else:
-            return dbapi.TIMESTAMP
+    render_bind_cast = True
 
 
 class AsyncpgBoolean(sqltypes.Boolean):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.BOOLEAN
+    render_bind_cast = True
 
 
 class AsyncPgInterval(INTERVAL):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTERVAL
+    render_bind_cast = True
 
     @classmethod
     def adapt_emulated_to_native(cls, interval, **kw):
@@ -168,49 +164,45 @@ class AsyncPgInterval(INTERVAL):
 
 
 class AsyncPgEnum(ENUM):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.ENUM
+    render_bind_cast = True
 
 
 class AsyncpgInteger(sqltypes.Integer):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    render_bind_cast = True
 
 
 class AsyncpgBigInteger(sqltypes.BigInteger):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.BIGINTEGER
+    render_bind_cast = True
 
 
 class AsyncpgJSON(json.JSON):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.JSON
+    render_bind_cast = True
 
     def result_processor(self, dialect, coltype):
         return None
 
 
 class AsyncpgJSONB(json.JSONB):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.JSONB
+    render_bind_cast = True
 
     def result_processor(self, dialect, coltype):
         return None
 
 
 class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
-    def get_dbapi_type(self, dbapi):
-        raise NotImplementedError("should not be here")
+    pass
 
 
 class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    __visit_name__ = "json_int_index"
+
+    render_bind_cast = True
 
 
 class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.STRING
+    __visit_name__ = "json_str_index"
+
+    render_bind_cast = True
 
 
 class AsyncpgJSONPathType(json.JSONPathType):
@@ -224,8 +216,7 @@ class AsyncpgJSONPathType(json.JSONPathType):
 
 
 class AsyncpgUUID(UUID):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.UUID
+    render_bind_cast = True
 
     def bind_processor(self, dialect):
         if not self.as_uuid and dialect.use_native_uuid:
@@ -249,8 +240,7 @@ class AsyncpgUUID(UUID):
 
 
 class AsyncpgNumeric(sqltypes.Numeric):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.NUMBER
+    render_bind_cast = True
 
     def bind_processor(self, dialect):
         return None
@@ -281,18 +271,16 @@ class AsyncpgNumeric(sqltypes.Numeric):
 
 
 class AsyncpgFloat(AsyncpgNumeric):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.FLOAT
+    __visit_name__ = "float"
+    render_bind_cast = True
 
 
 class AsyncpgREGCLASS(REGCLASS):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.STRING
+    render_bind_cast = True
 
 
 class AsyncpgOID(OID):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    render_bind_cast = True
 
 
 class PGExecutionContext_asyncpg(PGExecutionContext):
@@ -317,11 +305,6 @@ class PGExecutionContext_asyncpg(PGExecutionContext):
         if not self.compiled:
             return
 
-        # we have to exclude ENUM because "enum" not really a "type"
-        # we can cast to, it has to be the name of the type itself.
-        # for now we just omit it from casting
-        self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}
-
     def create_server_side_cursor(self):
         return self._dbapi_connection.cursor(server_side=True)
 
@@ -367,15 +350,7 @@ class AsyncAdapt_asyncpg_cursor:
         self._adapt_connection._handle_exception(error)
 
     def _parameter_placeholders(self, params):
-        if not self._inputsizes:
-            return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
-        else:
-            return tuple(
-                "$%d::%s" % (idx, typ) if typ else "$%d" % idx
-                for idx, typ in enumerate(
-                    (_pg_types.get(typ) for typ in self._inputsizes), 1
-                )
-            )
+        return tuple(f"${idx:d}" for idx, _ in enumerate(params, 1))
 
     async def _prepare_and_execute(self, operation, parameters):
         adapt_connection = self._adapt_connection
@@ -464,7 +439,7 @@ class AsyncAdapt_asyncpg_cursor:
         )
 
     def setinputsizes(self, *inputsizes):
-        self._inputsizes = inputsizes
+        raise NotImplementedError()
 
     def __iter__(self):
         while self._rows:
@@ -798,6 +773,12 @@ class AsyncAdapt_asyncpg_dbapi:
                 "all prepared caches in response to this exception)",
             )
 
+    # pep-249 datatype placeholders.  As of SQLAlchemy 2.0 these aren't
+    # used, however the test suite looks for these in a few cases.
+    STRING = util.symbol("STRING")
+    NUMBER = util.symbol("NUMBER")
+    DATETIME = util.symbol("DATETIME")
+
     @util.memoized_property
     def _asyncpg_error_translate(self):
         import asyncpg
@@ -814,50 +795,6 @@ class AsyncAdapt_asyncpg_dbapi:
     def Binary(self, value):
         return value
 
-    STRING = util.symbol("STRING")
-    TIMESTAMP = util.symbol("TIMESTAMP")
-    TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
-    TIME = util.symbol("TIME")
-    DATE = util.symbol("DATE")
-    INTERVAL = util.symbol("INTERVAL")
-    NUMBER = util.symbol("NUMBER")
-    FLOAT = util.symbol("FLOAT")
-    BOOLEAN = util.symbol("BOOLEAN")
-    INTEGER = util.symbol("INTEGER")
-    BIGINTEGER = util.symbol("BIGINTEGER")
-    BYTES = util.symbol("BYTES")
-    DECIMAL = util.symbol("DECIMAL")
-    JSON = util.symbol("JSON")
-    JSONB = util.symbol("JSONB")
-    ENUM = util.symbol("ENUM")
-    UUID = util.symbol("UUID")
-    BYTEA = util.symbol("BYTEA")
-
-    DATETIME = TIMESTAMP
-    BINARY = BYTEA
-
-
-_pg_types = {
-    AsyncAdapt_asyncpg_dbapi.STRING: "varchar",
-    AsyncAdapt_asyncpg_dbapi.TIMESTAMP: "timestamp",
-    AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone",
-    AsyncAdapt_asyncpg_dbapi.DATE: "date",
-    AsyncAdapt_asyncpg_dbapi.TIME: "time",
-    AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval",
-    AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric",
-    AsyncAdapt_asyncpg_dbapi.FLOAT: "float",
-    AsyncAdapt_asyncpg_dbapi.BOOLEAN: "bool",
-    AsyncAdapt_asyncpg_dbapi.INTEGER: "integer",
-    AsyncAdapt_asyncpg_dbapi.BIGINTEGER: "bigint",
-    AsyncAdapt_asyncpg_dbapi.BYTES: "bytes",
-    AsyncAdapt_asyncpg_dbapi.DECIMAL: "decimal",
-    AsyncAdapt_asyncpg_dbapi.JSON: "json",
-    AsyncAdapt_asyncpg_dbapi.JSONB: "jsonb",
-    AsyncAdapt_asyncpg_dbapi.ENUM: "enum",
-    AsyncAdapt_asyncpg_dbapi.UUID: "uuid",
-    AsyncAdapt_asyncpg_dbapi.BYTEA: "bytea",
-}
-
 
 class PGDialect_asyncpg(PGDialect):
     driver = "asyncpg"
@@ -865,19 +802,20 @@ class PGDialect_asyncpg(PGDialect):
 
     supports_server_side_cursors = True
 
+    render_bind_cast = True
+
     default_paramstyle = "format"
     supports_sane_multi_rowcount = False
     execution_ctx_cls = PGExecutionContext_asyncpg
     statement_compiler = PGCompiler_asyncpg
     preparer = PGIdentifierPreparer_asyncpg
 
-    use_setinputsizes = True
-
     use_native_uuid = True
 
     colspecs = util.update_copy(
         PGDialect.colspecs,
         {
+            sqltypes.String: AsyncpgString,
             sqltypes.Time: AsyncpgTime,
             sqltypes.Date: AsyncpgDate,
             sqltypes.DateTime: AsyncpgDateTime,
@@ -977,20 +915,6 @@ class PGDialect_asyncpg(PGDialect):
                 e, self.dbapi.InterfaceError
             ) and "connection is closed" in str(e)
 
-    def do_set_input_sizes(self, cursor, list_of_tuples, context):
-        if self.positional:
-            cursor.setinputsizes(
-                *[dbtype for key, dbtype, sqltype in list_of_tuples]
-            )
-        else:
-            cursor.setinputsizes(
-                **{
-                    key: dbtype
-                    for key, dbtype, sqltype in list_of_tuples
-                    if dbtype
-                }
-            )
-
     async def setup_asyncpg_json_codec(self, conn):
         """set up JSON codec for asyncpg.
 
index 583d9c2630820909f5d68549e05dae4f4add7005..800b289fb978c4f68ead55576a44ee2124e45b6d 100644 (file)
@@ -1388,6 +1388,7 @@ from ... import sql
 from ... import util
 from ...engine import characteristics
 from ...engine import default
+from ...engine import interfaces
 from ...engine import reflection
 from ...sql import coercions
 from ...sql import compiler
@@ -2041,16 +2042,6 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
             self.drop(bind=bind, checkfirst=checkfirst)
 
 
-class _ColonCast(elements.CompilerColumnElement):
-    __visit_name__ = "colon_cast"
-    __slots__ = ("type", "clause", "typeclause")
-
-    def __init__(self, expression, type_):
-        self.type = type_
-        self.clause = expression
-        self.typeclause = elements.TypeClause(type_)
-
-
 colspecs = {
     sqltypes.ARRAY: _array.ARRAY,
     sqltypes.Interval: INTERVAL,
@@ -2106,11 +2097,12 @@ ischema_names = {
 
 
 class PGCompiler(compiler.SQLCompiler):
-    def visit_colon_cast(self, element, **kw):
-        return "%s::%s" % (
-            element.clause._compiler_dispatch(self, **kw),
-            element.typeclause._compiler_dispatch(self, **kw),
-        )
+    def render_bind_cast(self, type_, dbapi_type, sqltext):
+        return f"""{sqltext}::{
+                self.dialect.type_compiler.process(
+                    dbapi_type, identifier_preparer=self.preparer
+                )
+            }"""
 
     def visit_array(self, element, **kw):
         return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
@@ -2854,6 +2846,12 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_TSTZRANGE(self, type_, **kw):
         return "TSTZRANGE"
 
+    def visit_json_int_index(self, type_, **kw):
+        return "INT"
+
+    def visit_json_str_index(self, type_, **kw):
+        return "TEXT"
+
     def visit_datetime(self, type_, **kw):
         return self.visit_TIMESTAMP(type_, **kw)
 
@@ -3121,6 +3119,8 @@ class PGDialect(default.DefaultDialect):
     max_identifier_length = 63
     supports_sane_rowcount = True
 
+    bind_typing = interfaces.BindTyping.RENDER_CASTS
+
     supports_native_enum = True
     supports_native_boolean = True
     supports_smallserial = True
index 324007e7ea12f0435abca3137fa94b81d3f29cc3..e849d049908077c51b99a9382688b94d8bc325f9 100644 (file)
@@ -94,7 +94,6 @@ import re
 from uuid import UUID as _python_UUID
 
 from .array import ARRAY as PGARRAY
-from .base import _ColonCast
 from .base import _DECIMAL_TYPES
 from .base import _FLOAT_TYPES
 from .base import _INT_TYPES
@@ -115,7 +114,13 @@ from ... import util
 from ...sql.elements import quoted_name
 
 
+class _PGString(sqltypes.String):
+    render_bind_cast = True
+
+
 class _PGNumeric(sqltypes.Numeric):
+    render_bind_cast = True
+
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
             if coltype in _FLOAT_TYPES:
@@ -141,26 +146,29 @@ class _PGNumeric(sqltypes.Numeric):
                 )
 
 
+class _PGFloat(_PGNumeric):
+    __visit_name__ = "float"
+    render_bind_cast = True
+
+
 class _PGNumericNoBind(_PGNumeric):
     def bind_processor(self, dialect):
         return None
 
 
 class _PGJSON(JSON):
+    render_bind_cast = True
+
     def result_processor(self, dialect, coltype):
         return None
 
-    def get_dbapi_type(self, dbapi):
-        return dbapi.JSON
-
 
 class _PGJSONB(JSONB):
+    render_bind_cast = True
+
     def result_processor(self, dialect, coltype):
         return None
 
-    def get_dbapi_type(self, dbapi):
-        return dbapi.JSONB
-
 
 class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
     def get_dbapi_type(self, dbapi):
@@ -168,21 +176,26 @@ class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
 
 
 class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    __visit_name__ = "json_int_index"
+
+    render_bind_cast = True
 
 
 class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.STRING
+    __visit_name__ = "json_str_index"
+
+    render_bind_cast = True
 
 
 class _PGJSONPathType(JSONPathType):
-    def get_dbapi_type(self, dbapi):
-        return 1009
+    pass
+
+    # DBAPI type 1009
 
 
 class _PGUUID(UUID):
+    render_bind_cast = True
+
     def bind_processor(self, dialect):
         if not self.as_uuid:
 
@@ -210,6 +223,8 @@ class _PGEnum(ENUM):
 
 
 class _PGInterval(INTERVAL):
+    render_bind_cast = True
+
     def get_dbapi_type(self, dbapi):
         return dbapi.INTERVAL
 
@@ -219,48 +234,39 @@ class _PGInterval(INTERVAL):
 
 
 class _PGTimeStamp(sqltypes.DateTime):
-    def get_dbapi_type(self, dbapi):
-        if self.timezone:
-            # TIMESTAMPTZOID
-            return 1184
-        else:
-            # TIMESTAMPOID
-            return 1114
+    render_bind_cast = True
+
+
+class _PGDate(sqltypes.Date):
+    render_bind_cast = True
 
 
 class _PGTime(sqltypes.Time):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.TIME
+    render_bind_cast = True
 
 
 class _PGInteger(sqltypes.Integer):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    render_bind_cast = True
 
 
 class _PGSmallInteger(sqltypes.SmallInteger):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    render_bind_cast = True
 
 
 class _PGNullType(sqltypes.NullType):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.NULLTYPE
+    pass
 
 
 class _PGBigInteger(sqltypes.BigInteger):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.BIGINTEGER
+    render_bind_cast = True
 
 
 class _PGBoolean(sqltypes.Boolean):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.BOOLEAN
+    render_bind_cast = True
 
 
 class _PGARRAY(PGARRAY):
-    def bind_expression(self, bindvalue):
-        return _ColonCast(bindvalue, self)
+    render_bind_cast = True
 
 
 _server_side_id = util.counter()
@@ -362,7 +368,7 @@ class PGDialect_pg8000(PGDialect):
     preparer = PGIdentifierPreparer_pg8000
     supports_server_side_cursors = True
 
-    use_setinputsizes = True
+    render_bind_cast = True
 
     # reversed as of pg8000 1.16.6.  1.16.5 and lower
     # are no longer compatible
@@ -372,8 +378,9 @@ class PGDialect_pg8000(PGDialect):
     colspecs = util.update_copy(
         PGDialect.colspecs,
         {
+            sqltypes.String: _PGString,
             sqltypes.Numeric: _PGNumericNoBind,
-            sqltypes.Float: _PGNumeric,
+            sqltypes.Float: _PGFloat,
             sqltypes.JSON: _PGJSON,
             sqltypes.Boolean: _PGBoolean,
             sqltypes.NullType: _PGNullType,
@@ -386,6 +393,7 @@ class PGDialect_pg8000(PGDialect):
             sqltypes.Interval: _PGInterval,
             INTERVAL: _PGInterval,
             sqltypes.DateTime: _PGTimeStamp,
+            sqltypes.Date: _PGDate,
             sqltypes.Time: _PGTime,
             sqltypes.Integer: _PGInteger,
             sqltypes.SmallInteger: _PGSmallInteger,
@@ -517,20 +526,6 @@ class PGDialect_pg8000(PGDialect):
         cursor.execute("COMMIT")
         cursor.close()
 
-    def do_set_input_sizes(self, cursor, list_of_tuples, context):
-        if self.positional:
-            cursor.setinputsizes(
-                *[dbtype for key, dbtype, sqltype in list_of_tuples]
-            )
-        else:
-            cursor.setinputsizes(
-                **{
-                    key: dbtype
-                    for key, dbtype, sqltype in list_of_tuples
-                    if dbtype
-                }
-            )
-
     def do_begin_twophase(self, connection, xid):
         connection.connection.tpc_begin((0, xid, ""))
 
index f62830a0d8ad55a51e2dabb884b000d87d5d3215..19c01d2088aac1c1262cd89d87a1a054506a5471 100644 (file)
@@ -449,7 +449,6 @@ import re
 from uuid import UUID as _python_UUID
 
 from .array import ARRAY as PGARRAY
-from .base import _ColonCast
 from .base import _DECIMAL_TYPES
 from .base import _FLOAT_TYPES
 from .base import _INT_TYPES
@@ -516,8 +515,7 @@ class _PGHStore(HSTORE):
 
 
 class _PGARRAY(PGARRAY):
-    def bind_expression(self, bindvalue):
-        return _ColonCast(bindvalue, self)
+    render_bind_cast = True
 
 
 class _PGJSON(JSON):
index ba57eee51d8b3f4a40484542556f39185ae58d78..5f4c5be47cd46507dc14be7c8844f5571d0da9d0 100644 (file)
@@ -33,6 +33,7 @@ from .cursor import CursorResult
 from .cursor import FullyBufferedResultProxy
 from .cursor import ResultProxy
 from .interfaces import AdaptedConnection
+from .interfaces import BindTyping
 from .interfaces import Compiled
 from .interfaces import Connectable
 from .interfaces import CreateEnginePlugin
index 389270e45927c448d2fcb12d0243183c9398a5fd..61ef29d4a43f77fa80a6b7dbfb61e1679d79371a 100644 (file)
@@ -9,6 +9,7 @@ from __future__ import with_statement
 import contextlib
 import sys
 
+from .interfaces import BindTyping
 from .interfaces import Connectable
 from .interfaces import ConnectionEventsTarget
 from .interfaces import ExceptionContext
@@ -1486,7 +1487,7 @@ class Connection(Connectable):
 
         context.pre_exec()
 
-        if dialect.use_setinputsizes:
+        if dialect.bind_typing is BindTyping.SETINPUTSIZES:
             context._set_input_sizes()
 
         cursor, statement, parameters = (
index 3af24d913520f980e7b9ff1b32848c3614b7df23..d36ed6e653d742435da326424c0daac841c32724 100644 (file)
@@ -52,9 +52,13 @@ class DefaultDialect(interfaces.Dialect):
     supports_alter = True
     supports_comments = False
     inline_comments = False
-    use_setinputsizes = False
     supports_statement_cache = True
 
+    bind_typing = interfaces.BindTyping.NONE
+
+    include_set_input_sizes = None
+    exclude_set_input_sizes = None
+
     # the first value we'd get for an autoincrement
     # column.
     default_sequence_base = 1
@@ -260,6 +264,15 @@ class DefaultDialect(interfaces.Dialect):
             else:
                 self.server_side_cursors = True
 
+        if getattr(self, "use_setinputsizes", False):
+            util.warn_deprecated(
+                "The dialect-level use_setinputsizes attribute is "
+                "deprecated.  Please use "
+                "bind_typing = BindTyping.SETINPUTSIZES",
+                "2.0",
+            )
+            self.bind_typing = interfaces.BindTyping.SETINPUTSIZES
+
         self.encoding = encoding
         self.positional = False
         self._ischema = None
@@ -287,6 +300,10 @@ class DefaultDialect(interfaces.Dialect):
         self.label_length = label_length
         self.compiler_linting = compiler_linting
 
+    @util.memoized_property
+    def _bind_typing_render_casts(self):
+        return self.bind_typing is interfaces.BindTyping.RENDER_CASTS
+
     def _ensure_has_table_connection(self, arg):
 
         if not isinstance(arg, Connection):
@@ -736,9 +753,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
     returned_default_rows = None
     execution_options = util.immutabledict()
 
-    include_set_input_sizes = None
-    exclude_set_input_sizes = None
-
     cursor_fetch_strategy = _cursor._DEFAULT_FETCH
 
     cache_stats = None
@@ -1373,8 +1387,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         style of ``setinputsizes()`` on the cursor, using DB-API types
         from the bind parameter's ``TypeEngine`` objects.
 
-        This method only called by those dialects which require it,
-        currently cx_oracle, asyncpg and pg8000.
+        This method is only called by those dialects which set
+        the :attr:`.Dialect.bind_typing` attribute to
+        :attr:`.BindTyping.SETINPUTSIZES`.   cx_Oracle is the only DBAPI
+        that requires setinputsizes(), pyodbc offers it as an option.
+
+        Prior to SQLAlchemy 2.0, the setinputsizes() approach was also used
+        for pg8000 and asyncpg; these dialects have since been changed to use
+        inline rendering of casts.
 
         """
         if self.isddl or self.is_text:
@@ -1382,10 +1402,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
 
         compiled = self.compiled
 
-        inputsizes = compiled._get_set_input_sizes_lookup(
-            include_types=self.include_set_input_sizes,
-            exclude_types=self.exclude_set_input_sizes,
-        )
+        inputsizes = compiled._get_set_input_sizes_lookup()
 
         if inputsizes is None:
             return
index 6772a27bda531b52b013056d60dcee353cef960a..251d01c5e21b41b20eebb7f7e9fe251816514f05 100644 (file)
@@ -7,10 +7,60 @@
 
 """Define core interfaces used by the engine system."""
 
+from enum import Enum
+
 from ..sql.compiler import Compiled  # noqa
 from ..sql.compiler import TypeCompiler  # noqa
 
 
+class BindTyping(Enum):
+    """Define different methods of passing typing information for
+    bound parameters in a statement to the database driver.
+
+    .. versionadded:: 2.0
+
+    """
+
+    NONE = 1
+    """No steps are taken to pass typing information to the database driver.
+
+    This is the default behavior for databases such as SQLite, MySQL / MariaDB,
+    SQL Server.
+
+    """
+
+    SETINPUTSIZES = 2
+    """Use the pep-249 setinputsizes method.
+
+    This is only implemented for DBAPIs that support this method and for which
+    the SQLAlchemy dialect has the appropriate infrastructure for that
+    dialect set up.   Currently this includes the cx_Oracle dialect, as
+    well as optional support for SQL Server when using pyodbc.
+
+    When using setinputsizes, dialects also have a means of only using the
+    method for certain datatypes using include/exclude lists.
+
+    When SETINPUTSIZES is used, the :meth:`.Dialect.do_set_input_sizes` method
+    is called for each statement executed which has bound parameters.
+
+    """
+
+    RENDER_CASTS = 3
+    """Render casts or other directives in the SQL string.
+
+    This method is used for all PostgreSQL dialects, including asyncpg,
+    pg8000, psycopg, psycopg2.   Dialects which implement this can choose
+    which kinds of datatypes are explicitly cast in SQL statements and which
+    aren't.
+
+    When RENDER_CASTS is used, the compiler will invoke the
+    :meth:`.SQLCompiler.render_bind_cast` method for each
+    :class:`.BindParameter` object whose dialect-level type sets the
+    :attr:`.TypeEngine.render_bind_cast` attribute.
+
+    """
+
+
 class Dialect:
     """Define the behavior of a specific database and DB-API combination.
 
@@ -156,6 +206,16 @@ class Dialect:
 
     """
 
+    bind_typing = BindTyping.NONE
+    """define a means of passing typing information to the database and/or
+    driver for bound parameters.
+
+    See :class:`.BindTyping` for values.
+
+    .. versionadded:: 2.0
+
+    """
+
     def create_connect_args(self, url):
         """Build DB-API compatible connection arguments.
 
@@ -587,7 +647,9 @@ class Dialect:
     def do_set_input_sizes(self, cursor, list_of_tuples, context):
         """invoke the cursor.setinputsizes() method with appropriate arguments
 
-        This hook is called if the dialect.use_inputsizes flag is set to True.
+        This hook is called if the :attr:`.Dialect.bind_typing` attribute is
+        set to the
+        :attr:`.BindTyping.SETINPUTSIZES` value.
         Parameter data is passed in a list of tuples (paramname, dbtype,
         sqltype), where ``paramname`` is the key of the parameter in the
         statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the
@@ -595,6 +657,12 @@ class Dialect:
 
         .. versionadded:: 1.4
 
+        .. versionchanged:: 2.0  - setinputsizes mode is now enabled by
+           setting :attr:`.Dialect.bind_typing` to
+           :attr:`.BindTyping.SETINPUTSIZES`.  Dialects which accept
+           a ``use_setinputsizes`` parameter should set this value
+           appropriately.
+
 
         """
         raise NotImplementedError()
index 29aa57faab6e0c92a3859cca74d0ae84e2bc525a..710c62c59796c017acd5ff25a564c9b4f3c8eb38 100644 (file)
@@ -227,6 +227,7 @@ FUNCTIONS = {
     functions.grouping_sets: "GROUPING SETS",
 }
 
+
 EXTRACT_MAP = {
     "month": "month",
     "day": "day",
@@ -1036,57 +1037,28 @@ class SQLCompiler(Compiled):
             return pd
 
     @util.memoized_instancemethod
-    def _get_set_input_sizes_lookup(
-        self, include_types=None, exclude_types=None
-    ):
-        if not hasattr(self, "bind_names"):
-            return None
-
+    def _get_set_input_sizes_lookup(self):
         dialect = self.dialect
-        dbapi = self.dialect.dbapi
 
-        # _unwrapped_dialect_impl() is necessary so that we get the
-        # correct dialect type for a custom TypeDecorator, or a Variant,
-        # which is also a TypeDecorator.   Special types like Interval,
-        # that use TypeDecorator but also might be mapped directly
-        # for a dialect impl, also subclass Emulated first which overrides
-        # this behavior in those cases to behave like the default.
+        include_types = dialect.include_set_input_sizes
+        exclude_types = dialect.exclude_set_input_sizes
 
-        if include_types is None and exclude_types is None:
+        dbapi = dialect.dbapi
 
-            def _lookup_type(typ):
-                dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
-                return dbtype
+        def lookup_type(typ):
+            dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi)
 
-        else:
-
-            def _lookup_type(typ):
-                # note we get dbtype from the possibly TypeDecorator-wrapped
-                # dialect_impl, but the dialect_impl itself that we use for
-                # include/exclude is the unwrapped version.
-
-                dialect_impl = typ._unwrapped_dialect_impl(dialect)
-
-                dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
-
-                if (
-                    dbtype is not None
-                    and (
-                        exclude_types is None
-                        or dbtype not in exclude_types
-                        and type(dialect_impl) not in exclude_types
-                    )
-                    and (
-                        include_types is None
-                        or dbtype in include_types
-                        or type(dialect_impl) in include_types
-                    )
-                ):
-                    return dbtype
-                else:
-                    return None
+            if (
+                dbtype is not None
+                and (exclude_types is None or dbtype not in exclude_types)
+                and (include_types is None or dbtype in include_types)
+            ):
+                return dbtype
+            else:
+                return None
 
         inputsizes = {}
+
         literal_execute_params = self.literal_execute_params
 
         for bindparam in self.bind_names:
@@ -1095,10 +1067,10 @@ class SQLCompiler(Compiled):
 
             if bindparam.type._is_tuple_type:
                 inputsizes[bindparam] = [
-                    _lookup_type(typ) for typ in bindparam.type.types
+                    lookup_type(typ) for typ in bindparam.type.types
                 ]
             else:
-                inputsizes[bindparam] = _lookup_type(bindparam.type)
+                inputsizes[bindparam] = lookup_type(bindparam.type)
 
         return inputsizes
 
@@ -2061,7 +2033,25 @@ class SQLCompiler(Compiled):
                 parameter, values
             )
 
-        typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+        dialect = self.dialect
+        typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)
+
+        if (
+            self.dialect._bind_typing_render_casts
+            and typ_dialect_impl.render_bind_cast
+        ):
+
+            def _render_bindtemplate(name):
+                return self.render_bind_cast(
+                    parameter.type,
+                    typ_dialect_impl,
+                    self.bindtemplate % {"name": name},
+                )
+
+        else:
+
+            def _render_bindtemplate(name):
+                return self.bindtemplate % {"name": name}
 
         if not values:
             to_update = []
@@ -2088,14 +2078,16 @@ class SQLCompiler(Compiled):
                 for i, tuple_element in enumerate(values, 1)
                 for j, value in enumerate(tuple_element, 1)
             ]
+
             replacement_expression = (
-                "VALUES " if self.dialect.tuple_in_values else ""
+                "VALUES " if dialect.tuple_in_values else ""
             ) + ", ".join(
                 "(%s)"
                 % (
                     ", ".join(
-                        self.bindtemplate
-                        % {"name": to_update[i * len(tuple_element) + j][0]}
+                        _render_bindtemplate(
+                            to_update[i * len(tuple_element) + j][0]
+                        )
                         for j, value in enumerate(tuple_element)
                     )
                 )
@@ -2107,7 +2099,7 @@ class SQLCompiler(Compiled):
                 for i, value in enumerate(values, 1)
             ]
             replacement_expression = ", ".join(
-                self.bindtemplate % {"name": key} for key, value in to_update
+                _render_bindtemplate(key) for key, value in to_update
             )
 
         return to_update, replacement_expression
@@ -2376,6 +2368,7 @@ class SQLCompiler(Compiled):
                     m = re.match(
                         r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
                     )
+                    assert m, "unexpected format for expanding parameter"
                     wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
                         m.group(2),
                         m.group(1),
@@ -2463,13 +2456,18 @@ class SQLCompiler(Compiled):
             name,
             post_compile=post_compile,
             expanding=bindparam.expanding,
+            bindparam_type=bindparam.type,
             **kwargs
         )
 
         if bindparam.expanding:
             ret = "(%s)" % ret
+
         return ret
 
+    def render_bind_cast(self, type_, dbapi_type, sqltext):
+        raise NotImplementedError()
+
     def render_literal_bindparam(
         self, bindparam, render_literal_value=NO_ARG, **kw
     ):
@@ -2556,6 +2554,7 @@ class SQLCompiler(Compiled):
         post_compile=False,
         expanding=False,
         escaped_from=None,
+        bindparam_type=None,
         **kw
     ):
 
@@ -2583,8 +2582,18 @@ class SQLCompiler(Compiled):
             self.escaped_bind_names[escaped_from] = name
         if post_compile:
             return "__[POSTCOMPILE_%s]" % name
-        else:
-            return self.bindtemplate % {"name": name}
+
+        ret = self.bindtemplate % {"name": name}
+
+        if (
+            bindparam_type is not None
+            and self.dialect._bind_typing_render_casts
+        ):
+            type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect)
+            if type_impl.render_bind_cast:
+                ret = self.render_bind_cast(bindparam_type, type_impl, ret)
+
+        return ret
 
     def visit_cte(
         self,
index 01763f266219891500c586613ce7c1602b6b2120..69c4a4a76ed6e967ed9d400f72281e2df7ea655e 100644 (file)
@@ -51,6 +51,21 @@ class TypeEngine(Traversible):
     _is_array = False
     _is_type_decorator = False
 
+    render_bind_cast = False
+    """Render bind casts for :attr:`.BindTyping.RENDER_CASTS` mode.
+
+    If True, this type (usually a dialect level impl type) signals
+    to the compiler that a cast should be rendered around a bound parameter
+    for this type.
+
+    .. versionadded:: 2.0
+
+    .. seealso::
+
+        :class:`.BindTyping`
+
+    """
+
     class Comparator(operators.ColumnOperators):
         """Base class for custom comparison operations defined at the
         type level.  See :attr:`.TypeEngine.comparator_factory`.
index 485a13f82f562dcfa283f18d8c93d5828d395130..6d1dac96f9ec2ec6e23ec03679b79d3e4a19748e 100644 (file)
@@ -267,6 +267,10 @@ class DialectSQL(CompiledSQL):
 
     def _dialect_adjusted_statement(self, paramstyle):
         stmt = re.sub(r"[\n\t]", "", self.statement)
+
+        # temporarily escape out PG double colons
+        stmt = stmt.replace("::", "!!")
+
         if paramstyle == "pyformat":
             stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
         else:
@@ -279,6 +283,10 @@ class DialectSQL(CompiledSQL):
             elif paramstyle == "numeric":
                 repl = None
             stmt = re.sub(r":([\w_]+)", repl, stmt)
+
+        # put them back
+        stmt = stmt.replace("!!", "::")
+
         return stmt
 
     def _compare_sql(self, execute_observed, received_statement):
index 4a5396ed82a4723e70fcdf2099f20e29db87f878..c1cbf1ec68fb9f2789309472143b03ec5f356083 100644 (file)
@@ -491,19 +491,15 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase):
             impl = String(50)
             cache_ok = True
 
-            def get_dbapi_type(self, dbapi):
-                return dbapi.NUMBER
-
             def column_expression(self, col):
                 return cast(col, Integer)
 
             def bind_expression(self, col):
-                return cast(col, String(50))
+                return cast(type_coerce(col, Integer), String(50))
 
         return StringAsInt()
 
     def test_special_type(self, metadata, connection, string_as_int):
-
         type_ = string_as_int
 
         t = Table("t", metadata, Column("x", type_))
index 7e91f0ebb459c8a907fffa3fc32d5de765d9bcd3..93513c39dbe1e0f7b3dd5a3351a79fd0fe8c52b2 100644 (file)
@@ -38,7 +38,6 @@ from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.dialects.postgresql import TSRANGE
-from sqlalchemy.dialects.postgresql.base import _ColonCast
 from sqlalchemy.dialects.postgresql.base import PGDialect
 from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
 from sqlalchemy.orm import aliased
@@ -99,14 +98,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
     __dialect__ = postgresql.dialect()
 
-    def test_colon_cast_is_slots(self):
-
-        c1 = _ColonCast(column("q"), String(50))
-
-        assert not hasattr(c1, "__dict__")
-
-        self.assert_compile(c1, "q::VARCHAR(50)")
-
     def test_update_returning(self):
         dialect = postgresql.dialect()
         table1 = table(
index 04bce4e22ca4c0304e8fc0504791024e126aa6f9..b488b146cf70248a89f83a6753e870c9f2967cba 100644 (file)
@@ -149,6 +149,13 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
         metadata.create_all(connection)
         self._assert_data_noautoincrement(connection, table)
 
+    def _ints_and_strs_setinputsizes(self, connection):
+
+        return (
+            connection.dialect._bind_typing_render_casts
+            and String().dialect_impl(connection.dialect).render_bind_cast
+        )
+
     def _assert_data_autoincrement(self, connection, table):
         """
         invoked by:
@@ -190,31 +197,64 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
 
             conn.execute(table.insert().inline(), {"data": "d8"})
 
-        asserter.assert_(
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                {"id": 30, "data": "d1"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                {"id": 1, "data": "d2"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (data) VALUES (:data)",
-                [{"data": "d5"}, {"data": "d6"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 33, "data": "d7"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}]
-            ),
-        )
+        if self._ints_and_strs_setinputsizes(connection):
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    {"id": 30, "data": "d1"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    {"id": 1, "data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data::VARCHAR(30))",
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data::VARCHAR(30))",
+                    [{"data": "d8"}],
+                ),
+            )
+        else:
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    {"id": 30, "data": "d1"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    {"id": 1, "data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data)",
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data)",
+                    [{"data": "d8"}],
+                ),
+            )
 
         eq_(
             conn.execute(table.select()).fetchall(),
@@ -255,31 +295,64 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             conn.execute(table.insert().inline(), {"id": 33, "data": "d7"})
             conn.execute(table.insert().inline(), {"data": "d8"})
 
-        asserter.assert_(
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                {"id": 30, "data": "d1"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                {"id": 5, "data": "d2"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (data) VALUES (:data)",
-                [{"data": "d5"}, {"data": "d6"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 33, "data": "d7"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}]
-            ),
-        )
+        if self._ints_and_strs_setinputsizes(connection):
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    {"id": 30, "data": "d1"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    {"id": 5, "data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data::VARCHAR(30))",
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data::VARCHAR(30))",
+                    [{"data": "d8"}],
+                ),
+            )
+        else:
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    {"id": 30, "data": "d1"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    {"id": 5, "data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data)",
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data)",
+                    [{"data": "d8"}],
+                ),
+            )
 
         eq_(
             conn.execute(table.select()).fetchall(),
@@ -336,32 +409,66 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
 
             conn.execute(table.insert().inline(), {"data": "d8"})
 
-        asserter.assert_(
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                {"id": 30, "data": "d1"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (data) VALUES (:data) RETURNING "
-                "testtable.id",
-                {"data": "d2"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (data) VALUES (:data)",
-                [{"data": "d5"}, {"data": "d6"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 33, "data": "d7"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}]
-            ),
-        )
+        if self._ints_and_strs_setinputsizes(connection):
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    {"id": 30, "data": "d1"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES "
+                    "(:data::VARCHAR(30)) RETURNING "
+                    "testtable.id",
+                    {"data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data::VARCHAR(30))",
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data::VARCHAR(30))",
+                    [{"data": "d8"}],
+                ),
+            )
+        else:
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    {"id": 30, "data": "d1"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data) RETURNING "
+                    "testtable.id",
+                    {"data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data)",
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data)",
+                    [{"data": "d8"}],
+                ),
+            )
 
         eq_(
             conn.execute(table.select()).fetchall(),
@@ -404,32 +511,66 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             conn.execute(table.insert().inline(), {"id": 33, "data": "d7"})
             conn.execute(table.insert().inline(), {"data": "d8"})
 
-        asserter.assert_(
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                {"id": 30, "data": "d1"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (data) VALUES (:data) RETURNING "
-                "testtable.id",
-                {"data": "d2"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (data) VALUES (:data)",
-                [{"data": "d5"}, {"data": "d6"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 33, "data": "d7"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}]
-            ),
-        )
+        if self._ints_and_strs_setinputsizes(connection):
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    {"id": 30, "data": "d1"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES "
+                    "(:data::VARCHAR(30)) RETURNING "
+                    "testtable.id",
+                    {"data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data::VARCHAR(30))",
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data::VARCHAR(30))",
+                    [{"data": "d8"}],
+                ),
+            )
+        else:
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    {"id": 30, "data": "d1"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data) RETURNING "
+                    "testtable.id",
+                    {"data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data)",
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (data) VALUES (:data)",
+                    [{"data": "d8"}],
+                ),
+            )
 
         eq_(
             conn.execute(table.select()).fetchall(),
@@ -466,35 +607,70 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             conn.execute(table.insert().inline(), {"id": 33, "data": "d7"})
             conn.execute(table.insert().inline(), {"data": "d8"})
 
-        asserter.assert_(
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                {"id": 30, "data": "d1"},
-            ),
-            CursorSQL("select nextval('my_seq')", consume_statement=False),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                {"id": 1, "data": "d2"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
-                ":data)" % seqname,
-                [{"data": "d5"}, {"data": "d6"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 33, "data": "d7"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
-                ":data)" % seqname,
-                [{"data": "d8"}],
-            ),
-        )
+        if self._ints_and_strs_setinputsizes(connection):
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    {"id": 30, "data": "d1"},
+                ),
+                CursorSQL("select nextval('my_seq')", consume_statement=False),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    {"id": 1, "data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
+                    ":data::VARCHAR(30))" % seqname,
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
+                    ":data::VARCHAR(30))" % seqname,
+                    [{"data": "d8"}],
+                ),
+            )
+        else:
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    {"id": 30, "data": "d1"},
+                ),
+                CursorSQL("select nextval('my_seq')", consume_statement=False),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    {"id": 1, "data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
+                    ":data)" % seqname,
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
+                    ":data)" % seqname,
+                    [{"data": "d8"}],
+                ),
+            )
         eq_(
             conn.execute(table.select()).fetchall(),
             [
@@ -530,35 +706,70 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             conn.execute(table.insert().inline(), {"id": 33, "data": "d7"})
             conn.execute(table.insert().inline(), {"data": "d8"})
 
-        asserter.assert_(
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                {"id": 30, "data": "d1"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES "
-                "(nextval('my_seq'), :data) RETURNING testtable.id",
-                {"data": "d2"},
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
-                ":data)" % seqname,
-                [{"data": "d5"}, {"data": "d6"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
-                [{"id": 33, "data": "d7"}],
-            ),
-            DialectSQL(
-                "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
-                ":data)" % seqname,
-                [{"data": "d8"}],
-            ),
-        )
+        if self._ints_and_strs_setinputsizes(connection):
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    {"id": 30, "data": "d1"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(nextval('my_seq'), :data::VARCHAR(30)) "
+                    "RETURNING testtable.id",
+                    {"data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
+                    ":data::VARCHAR(30))" % seqname,
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(:id::INTEGER, :data::VARCHAR(30))",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
+                    ":data::VARCHAR(30))" % seqname,
+                    [{"data": "d8"}],
+                ),
+            )
+        else:
+            asserter.assert_(
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    {"id": 30, "data": "d1"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES "
+                    "(nextval('my_seq'), :data) RETURNING testtable.id",
+                    {"data": "d2"},
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 31, "data": "d3"}, {"id": 32, "data": "d4"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
+                    ":data)" % seqname,
+                    [{"data": "d5"}, {"data": "d6"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                    [{"id": 33, "data": "d7"}],
+                ),
+                DialectSQL(
+                    "INSERT INTO testtable (id, data) VALUES (nextval('%s'), "
+                    ":data)" % seqname,
+                    [{"data": "d8"}],
+                ),
+            )
 
         eq_(
             connection.execute(table.select()).fetchall(),
@@ -758,7 +969,9 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
         matchtable = self.tables.matchtable
         self.assert_compile(
             matchtable.c.title.match("somstr"),
-            "matchtable.title @@ to_tsquery(%s)",
+            # note we assume current tested DBAPIs use emulated setinputsizes
+            # here, the cast is not strictly necessary
+            "matchtable.title @@ to_tsquery(%s::VARCHAR(200))",
         )
 
     def test_simple_match(self, connection):
index 4c0b91f93d11ff84af035e17bc750e6c8d4a8741..25858af9b9518498c12f86766f1a6f29ddbe3445 100644 (file)
@@ -1206,7 +1206,7 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
 
         self.assert_compile(
             expr,
-            "x IN (__[POSTCOMPILE_x_1~~~~REPL~~::myenum[]~~])",
+            "x IN (__[POSTCOMPILE_x_1])",
             dialect=postgresql.psycopg2.dialect(),
         )
 
@@ -1224,7 +1224,7 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
 
         self.assert_compile(
             expr,
-            "x IN (__[POSTCOMPILE_x_1~~~~REPL~~::VARCHAR(15)[]~~])",
+            "x IN (__[POSTCOMPILE_x_1])",
             dialect=postgresql.psycopg2.dialect(),
         )
 
index d0cbe1a77387cae3210a5599a8dc80c0fc2f9b3f..548504e18d0f899bd6c28a478518e5b873227f57 100644 (file)
@@ -17,9 +17,11 @@ from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import ThreadLocalMetaData
+from sqlalchemy.engine import BindTyping
 from sqlalchemy.engine import reflection
 from sqlalchemy.engine.base import Connection
 from sqlalchemy.engine.base import Engine
+from sqlalchemy.engine.default import DefaultDialect
 from sqlalchemy.engine.mock import MockConnection
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
@@ -311,6 +313,16 @@ class CreateEngineTest(fixtures.TestBase):
                 _initialize=False,
             )
 
+    def test_dialect_use_setinputsizes_attr(self):
+        class MyDialect(DefaultDialect):
+            use_setinputsizes = True
+
+        with testing.expect_deprecated(
+            "The dialect-level use_setinputsizes attribute is deprecated."
+        ):
+            md = MyDialect()
+        is_(md.bind_typing, BindTyping.SETINPUTSIZES)
+
 
 class HandleInvalidatedOnConnectTest(fixtures.TestBase):
     __requires__ = ("sqlite",)
index e1c86444d55ba2e0e3bc35270806034405688fde..23a8bff7c058edd65c4e1f9a658f607e2debde5c 100644 (file)
@@ -24,6 +24,7 @@ from sqlalchemy import text
 from sqlalchemy import TypeDecorator
 from sqlalchemy import util
 from sqlalchemy import VARCHAR
+from sqlalchemy.engine import BindTyping
 from sqlalchemy.engine import default
 from sqlalchemy.engine.base import Connection
 from sqlalchemy.engine.base import Engine
@@ -3654,7 +3655,7 @@ class SetInputSizesTest(fixtures.TablesTest):
         # setinputsizes() called in order to work.
 
         with mock.patch.object(
-            engine.dialect, "use_setinputsizes", True
+            engine.dialect, "bind_typing", BindTyping.SETINPUTSIZES
         ), mock.patch.object(
             engine.dialect, "do_set_input_sizes", do_set_input_sizes
         ), mock.patch.object(
index 7f22929868b238a38037d92b23005ed8270c9cce..cb83bb6f7808c9aed057ee4c0a11dcb03379fd89 100644 (file)
@@ -1582,7 +1582,7 @@ class TypeCoerceTest(fixtures.MappedTest, testing.AssertsExecutionResults):
             return sa.cast(col, Integer)
 
         def bind_expression(self, col):
-            return sa.cast(col, String(50))
+            return sa.cast(sa.type_coerce(col, Integer), String(50))
 
     @classmethod
     def define_tables(cls, metadata):
index 967547ab9c9ccf3d86d02562d16fe47d4432646f..eeb71323bc59cea676810940216ac674cba6f2f7 100644 (file)
@@ -6,6 +6,7 @@
 import sys
 
 from sqlalchemy import exc
+from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql import text
 from sqlalchemy.testing import exclusions
 from sqlalchemy.testing.exclusions import against
@@ -213,6 +214,7 @@ class DefaultRequirements(SuiteRequirements):
                 "mariadb+pymysql",
                 "mariadb+cymysql",
                 "mariadb+mysqlconnector",
+                "postgresql+asyncpg",
                 "postgresql+pg8000",
             ]
         )
@@ -795,6 +797,24 @@ class DefaultRequirements(SuiteRequirements):
             ["oracle"], "oracle converts empty strings to a blank space"
         )
 
+    @property
+    def string_type_isnt_subtype(self):
+        """target dialect does not have a dialect-specific subtype for String.
+
+        This is used for a special type expression test which wants to
+        test the compiler with a subclass of String, where we don't want
+        the dialect changing that type when we grab the 'impl'.
+
+        """
+
+        def go(config):
+            return (
+                sqltypes.String().dialect_impl(config.db.dialect).__class__
+                is sqltypes.String
+            )
+
+        return only_if(go)
+
     @property
     def empty_inserts_executemany(self):
         # waiting on https://jira.mariadb.org/browse/CONPY-152
index cab4f637112345074daab6720a3bbd19c5191e66..70c8839e30a1f523f33d18af7e3ccf32a18fac1e 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import testing
 from sqlalchemy import TypeDecorator
 from sqlalchemy import union
 from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL
+from sqlalchemy.sql.type_api import UserDefinedType
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
@@ -431,6 +432,8 @@ class RoundTripTestBase:
 
 
 class StringRoundTripTest(fixtures.TablesTest, RoundTripTestBase):
+    __requires__ = ("string_type_isnt_subtype",)
+
     @classmethod
     def define_tables(cls, metadata):
         class MyString(String):
@@ -448,6 +451,29 @@ class StringRoundTripTest(fixtures.TablesTest, RoundTripTestBase):
         )
 
 
+class UserDefinedTypeRoundTripTest(fixtures.TablesTest, RoundTripTestBase):
+    @classmethod
+    def define_tables(cls, metadata):
+        class MyString(UserDefinedType):
+            cache_ok = True
+
+            def get_col_spec(self, **kw):
+                return "VARCHAR(50)"
+
+            def bind_expression(self, bindvalue):
+                return func.lower(bindvalue)
+
+            def column_expression(self, col):
+                return func.upper(col)
+
+        Table(
+            "test_table",
+            metadata,
+            Column("x", String(50)),
+            Column("y", MyString()),
+        )
+
+
 class TypeDecRoundTripTest(fixtures.TablesTest, RoundTripTestBase):
     @classmethod
     def define_tables(cls, metadata):
@@ -474,7 +500,11 @@ class ReturningTest(fixtures.TablesTest):
 
     @classmethod
     def define_tables(cls, metadata):
-        class MyString(String):
+        class MyString(TypeDecorator):
+            impl = String
+
+            cache_ok = True
+
             def column_expression(self, col):
                 return func.lower(col)