]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Revise setinputsizes approach
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 19 Jul 2020 21:39:14 +0000 (17:39 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 19 Jul 2020 21:45:41 +0000 (17:45 -0400)
in order to support asyncpg as well as pg8000,
we need to revise setinputsizes to work for more cases as well
as adjust NativeForEmulated a bit to work more completely with
the INTERVAL datatype.

- put most of the setinputsizes work into the compiler where
the computation can be cached.

- support per-element setinputsizes for a tuple

- adjust TypeDecorator so that _unwrapped_dialect_impl
will honor a type that the dialect links to directly in
it's adaption mapping.    Decouble _unwrapped_dialect_impl
from TypeDecorator._gen_dialect_impl() which has a different
purpose.   This allows setinputsizes to do the right thing
with the INTERVAL datatype.

- test cases for Oracle with Variant continue to work

Change-Id: I9e1ea33aeca3b92b365daa4a356d778191070c03

lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/type_api.py

index e567e11e798e2ef2c1a96f7bfc1185a2b406ee84..c431fa7555f3da2484ce82f4551a149e1648c4b9 100644 (file)
@@ -1406,37 +1406,19 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         currently cx_oracle.
 
         """
+        if self.isddl:
+            return None
 
-        if not hasattr(self.compiled, "bind_names"):
+        inputsizes = self.compiled._get_set_input_sizes_lookup(
+            translate=translate,
+            include_types=include_types,
+            exclude_types=exclude_types,
+        )
+        if inputsizes is None:
             return
 
-        inputsizes = {}
-        for bindparam in self.compiled.bind_names:
-            if bindparam in self.compiled.literal_execute_params:
-                continue
-
-            dialect_impl = bindparam.type._unwrapped_dialect_impl(self.dialect)
-            dialect_impl_cls = type(dialect_impl)
-            dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi)
-
-            if (
-                dbtype is not None
-                and (
-                    not exclude_types
-                    or dbtype not in exclude_types
-                    and dialect_impl_cls not in exclude_types
-                )
-                and (
-                    not include_types
-                    or dbtype in include_types
-                    or dialect_impl_cls in include_types
-                )
-            ):
-                inputsizes[bindparam] = dbtype
-            else:
-                inputsizes[bindparam] = None
-
         if self.dialect._has_events:
+            inputsizes = dict(inputsizes)
             self.dialect.dispatch.do_setinputsizes(
                 inputsizes, self.cursor, self.statement, self.parameters, self
             )
@@ -1445,14 +1427,29 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             positional_inputsizes = []
             for key in self.compiled.positiontup:
                 bindparam = self.compiled.binds[key]
-                dbtype = inputsizes.get(bindparam, None)
-                if dbtype is not None:
-                    if key in self._expanded_parameters:
+                if bindparam in self.compiled.literal_execute_params:
+                    continue
+
+                if key in self._expanded_parameters:
+                    if bindparam._expanding_in_types:
+                        num = len(bindparam._expanding_in_types)
+                        dbtypes = inputsizes[bindparam]
                         positional_inputsizes.extend(
-                            [dbtype] * len(self._expanded_parameters[key])
+                            [
+                                dbtypes[idx % num]
+                                for idx, key in enumerate(
+                                    self._expanded_parameters[key]
+                                )
+                            ]
                         )
                     else:
-                        positional_inputsizes.append(dbtype)
+                        dbtype = inputsizes.get(bindparam, None)
+                        positional_inputsizes.extend(
+                            dbtype for dbtype in self._expanded_parameters[key]
+                        )
+                else:
+                    dbtype = inputsizes[bindparam]
+                    positional_inputsizes.append(dbtype)
             try:
                 self.cursor.setinputsizes(*positional_inputsizes)
             except BaseException as e:
@@ -1462,21 +1459,40 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         else:
             keyword_inputsizes = {}
             for bindparam, key in self.compiled.bind_names.items():
-                dbtype = inputsizes.get(bindparam, None)
-                if dbtype is not None:
-                    if translate:
-                        # TODO: this part won't work w/ the
-                        # expanded_parameters feature, e.g. for cx_oracle
-                        # quoted bound names
-                        key = translate.get(key, key)
-                    if not self.dialect.supports_unicode_binds:
-                        key = self.dialect._encoder(key)[0]
-                    if key in self._expanded_parameters:
+                if bindparam in self.compiled.literal_execute_params:
+                    continue
+
+                if key in self._expanded_parameters:
+                    if bindparam._expanding_in_types:
+                        num = len(bindparam._expanding_in_types)
+                        dbtypes = inputsizes[bindparam]
                         keyword_inputsizes.update(
-                            (expand_key, dbtype)
-                            for expand_key in self._expanded_parameters[key]
+                            [
+                                (key, dbtypes[idx % num])
+                                for idx, key in enumerate(
+                                    self._expanded_parameters[key]
+                                )
+                            ]
                         )
                     else:
+                        dbtype = inputsizes.get(bindparam, None)
+                        if dbtype is not None:
+                            keyword_inputsizes.update(
+                                (expand_key, dbtype)
+                                for expand_key in self._expanded_parameters[
+                                    key
+                                ]
+                            )
+                else:
+                    dbtype = inputsizes.get(bindparam, None)
+                    if dbtype is not None:
+                        if translate:
+                            # TODO: this part won't work w/ the
+                            # expanded_parameters feature, e.g. for cx_oracle
+                            # quoted bound names
+                            key = translate.get(key, key)
+                        if not self.dialect.supports_unicode_binds:
+                            key = self.dialect._encoder(key)[0]
                         keyword_inputsizes[key] = dbtype
             try:
                 self.cursor.setinputsizes(**keyword_inputsizes)
index 3a3ce5c45dab9e77e9026eba13a0c50708aeb670..61e26b003c0c29e9308fbb8954ecc7353a432c4b 100644 (file)
@@ -957,6 +957,68 @@ class SQLCompiler(Compiled):
                     pd[name] = value_param.value
             return pd
 
+    @util.memoized_instancemethod
+    def _get_set_input_sizes_lookup(
+        self, translate=None, include_types=None, exclude_types=None
+    ):
+        if not hasattr(self, "bind_names"):
+            return None
+
+        dialect = self.dialect
+        dbapi = self.dialect.dbapi
+
+        # _unwrapped_dialect_impl() is necessary so that we get the
+        # correct dialect type for a custom TypeDecorator, or a Variant,
+        # which is also a TypeDecorator.   Special types like Interval,
+        # that use TypeDecorator but also might be mapped directly
+        # for a dialect impl, also subclass Emulated first which overrides
+        # this behavior in those cases to behave like the default.
+
+        if not include_types and not exclude_types:
+
+            def _lookup_type(typ):
+                dialect_impl = typ._unwrapped_dialect_impl(dialect)
+                return dialect_impl.get_dbapi_type(dbapi)
+
+        else:
+
+            def _lookup_type(typ):
+                dialect_impl = typ._unwrapped_dialect_impl(dialect)
+                dbtype = dialect_impl.get_dbapi_type(dbapi)
+
+                if (
+                    dbtype is not None
+                    and (
+                        not exclude_types
+                        or dbtype not in exclude_types
+                        and type(dialect_impl) not in exclude_types
+                    )
+                    and (
+                        not include_types
+                        or dbtype in include_types
+                        or type(dialect_impl) in include_types
+                    )
+                ):
+                    return dbtype
+                else:
+                    return None
+
+        inputsizes = {}
+        literal_execute_params = self.literal_execute_params
+
+        for bindparam in self.bind_names:
+            if bindparam in literal_execute_params:
+                continue
+
+            if bindparam._expanding_in_types:
+                inputsizes[bindparam] = [
+                    _lookup_type(typ) for typ in bindparam._expanding_in_types
+                ]
+            else:
+                inputsizes[bindparam] = _lookup_type(bindparam.type)
+
+        return inputsizes
+
     @property
     def params(self):
         """Return the bind param dictionary embedded into this
@@ -4546,6 +4608,7 @@ class IdentifierPreparer(object):
 
         if name is None:
             name = table.name
+
         result = self.quote(name)
 
         effective_schema = self.schema_for_object(table)
index 83c7960ac8c7b343dd615d43e40ae928f37cd581..2d23c56e182c412181bb8e48b792b88bdc75aa27 100644 (file)
@@ -745,20 +745,20 @@ class Emulated(object):
 
     def adapt(self, impltype, **kw):
         if hasattr(impltype, "adapt_emulated_to_native"):
-
             if self.native:
                 # native support requested, dialect gave us a native
                 # implementor, pass control over to it
                 return impltype.adapt_emulated_to_native(self, **kw)
             else:
-                # impltype adapts to native, and we are not native,
-                # so reject the impltype in favor of "us"
-                impltype = self.__class__
-
-        if issubclass(impltype, self.__class__):
-            return self.adapt_to_emulated(impltype, **kw)
+                # non-native support, let the native implementor
+                # decide also, at the moment this is just to help debugging
+                # as only the default logic is implemented.
+                return impltype.adapt_native_to_emulated(self, **kw)
         else:
-            return super(Emulated, self).adapt(impltype, **kw)
+            if issubclass(impltype, self.__class__):
+                return self.adapt_to_emulated(impltype, **kw)
+            else:
+                return super(Emulated, self).adapt(impltype, **kw)
 
 
 class NativeForEmulated(object):
@@ -768,6 +768,16 @@ class NativeForEmulated(object):
 
     """
 
+    @classmethod
+    def adapt_native_to_emulated(cls, impl, **kw):
+        """Given an impl, adapt this type's class to the impl assuming
+        "emulated".
+
+
+        """
+        impltype = impl.__class__
+        return impl.adapt(impltype, **kw)
+
     @classmethod
     def adapt_emulated_to_native(cls, impl, **kw):
         """Given an impl, adapt this type's class to the impl assuming "native".
@@ -974,7 +984,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         # otherwise adapt the impl type, link
         # to a copy of this TypeDecorator and return
         # that.
-        typedesc = self._unwrapped_dialect_impl(dialect)
+        typedesc = self.load_dialect_impl(dialect).dialect_impl(dialect)
         tt = self.copy()
         if not isinstance(tt, self.__class__):
             raise AssertionError(
@@ -1045,16 +1055,21 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
     def _unwrapped_dialect_impl(self, dialect):
         """Return the 'unwrapped' dialect impl for this type.
 
-        For a type that applies wrapping logic (e.g. TypeDecorator), give
-        us the real, actual dialect-level type that is used.
-
-        This is used by TypeDecorator itself as well at least one case where
-        dialects need to check that a particular specific dialect-level
-        type is in use, within the :meth:`.DefaultDialect.set_input_sizes`
+        This is used by the :meth:`.DefaultDialect.set_input_sizes`
         method.
 
         """
-        return self.load_dialect_impl(dialect).dialect_impl(dialect)
+
+        # some dialects have a lookup for a TypeDecorator subclass directly.
+        # postgresql.INTERVAL being the main example
+        typ = self.dialect_impl(dialect)
+
+        # if we are still a type decorator, load the per-dialect switch
+        # (such as what Variant uses), then get the dialect impl for that.
+        if isinstance(typ, self.__class__):
+            return typ.load_dialect_impl(dialect).dialect_impl(dialect)
+        else:
+            return typ
 
     def __getattr__(self, key):
         """Proxy all other undefined accessors to the underlying