]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Convert schema_translate to a post compile
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Mar 2020 18:52:05 +0000 (14:52 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Mar 2020 18:25:41 +0000 (14:25 -0400)
Revised the :paramref:`.Connection.execution_options.schema_translate_map`
feature such that the processing of the SQL statement to receive a specific
schema name occurs within the execution phase of the statement, rather than
at the compile phase.   This is to support the statement being efficiently
cached.   Previously, the current schema being rendered into the statement
for a particular run would be considered as part of the cache key itself,
meaning that for a run against hundreds of schemas, there would be hundreds
of cache keys, rendering the cache much less performant.  The new behavior
is that the rendering is done in a similar  manner as the "post compile"
rendering added in 1.4 as part of :ticket:`4645`, :ticket:`4808`.

Fixes: #5004
Change-Id: Ia5c89eb27cc8dc2c5b8e76d6c07c46290a7901b6

17 files changed:
doc/build/changelog/unreleased_14/5004.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/mock.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/assertsql.py
lib/sqlalchemy/testing/suite/test_reflection.py
test/dialect/postgresql/test_compiler.py
test/engine/test_execute.py
test/sql/test_compiler.py
test/sql/test_ddlemit.py

diff --git a/doc/build/changelog/unreleased_14/5004.rst b/doc/build/changelog/unreleased_14/5004.rst
new file mode 100644 (file)
index 0000000..a13a9a7
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 5004
+
+    Revised the :paramref:`.Connection.execution_options.schema_translate_map`
+    feature such that the processing of the SQL statement to receive a specific
+    schema name occurs within the execution phase of the statement, rather than
+    at the compile phase.   This is to support the statement being efficiently
+    cached.   Previously, the current schema being rendered into the statement
+    for a particular run would be considered as part of the cache key itself,
+    meaning that for a run against hundreds of schemas, there would be hundreds
+    of cache keys, rendering the cache much less performant.  The new behavior
+    is that the rendering is done in a similar  manner as the "post compile"
+    rendering added in 1.4 as part of :ticket:`4645`, :ticket:`4808`.
index a63ce0033f1cc1f1564c7a8c8377f12720e861a0..31425d4c0bc1b4c76e93610c18bf40690473bd72 100644 (file)
@@ -1631,6 +1631,9 @@ class SQLiteDialect(default.DefaultDialect):
         )
         return bool(info)
 
+    def _get_default_schema_name(self, connection):
+        return "main"
+
     @reflection.cache
     def get_view_names(self, connection, schema=None, **kw):
         if schema is not None:
index aa21fb13bb77ce28f425e295fa223447a3ed76fb..4ed3b9af7a98bc014100c1d97b307a9fa3884e3b 100644 (file)
@@ -17,7 +17,6 @@ from .. import inspection
 from .. import log
 from .. import util
 from ..sql import compiler
-from ..sql import schema
 from ..sql import util as sql_util
 
 
@@ -51,21 +50,7 @@ class Connection(Connectable):
 
     """
 
-    schema_for_object = schema._schema_getter(None)
-    """Return the ".schema" attribute for an object.
-
-    Used for :class:`.Table`, :class:`.Sequence` and similar objects,
-    and takes into account
-    the :paramref:`.Connection.execution_options.schema_translate_map`
-    parameter.
-
-      .. versionadded:: 1.1
-
-      .. seealso::
-
-          :ref:`schema_translating`
-
-    """
+    _schema_translate_map = None
 
     def __init__(
         self,
@@ -92,7 +77,7 @@ class Connection(Connectable):
             self.should_close_with_result = False
             self.dispatch = _dispatch
             self._has_events = _branch_from._has_events
-            self.schema_for_object = _branch_from.schema_for_object
+            self._schema_translate_map = _branch_from._schema_translate_map
         else:
             self.__connection = (
                 connection
@@ -122,6 +107,24 @@ class Connection(Connectable):
         if self._has_events or self.engine._has_events:
             self.dispatch.engine_connect(self, self.__branch)
 
+    def schema_for_object(self, obj):
+        """return the schema name for the given schema item taking into
+        account current schema translate map.
+
+        """
+
+        name = obj.schema
+        schema_translate_map = self._schema_translate_map
+
+        if (
+            schema_translate_map
+            and name in schema_translate_map
+            and obj._use_schema_map
+        ):
+            return schema_translate_map[name]
+        else:
+            return name
+
     def _branch(self):
         """Return a new Connection which references this Connection's
         engine and connection; but does not have close_with_result enabled,
@@ -1066,10 +1069,7 @@ class Connection(Connectable):
         dialect = self.dialect
 
         compiled = ddl.compile(
-            dialect=dialect,
-            schema_translate_map=self.schema_for_object
-            if not self.schema_for_object.is_default
-            else None,
+            dialect=dialect, schema_translate_map=self._schema_translate_map
         )
         ret = self._execute_context(
             dialect,
@@ -1103,7 +1103,7 @@ class Connection(Connectable):
                 dialect,
                 elem,
                 tuple(sorted(keys)),
-                self.schema_for_object.hash_key,
+                bool(self._schema_translate_map),
                 len(distilled_params) > 1,
             )
             compiled_sql = self._execution_options["compiled_cache"].get(key)
@@ -1112,9 +1112,7 @@ class Connection(Connectable):
                     dialect=dialect,
                     column_keys=keys,
                     inline=len(distilled_params) > 1,
-                    schema_translate_map=self.schema_for_object
-                    if not self.schema_for_object.is_default
-                    else None,
+                    schema_translate_map=self._schema_translate_map,
                     linting=self.dialect.compiler_linting
                     | compiler.WARN_LINTING,
                 )
@@ -1124,9 +1122,7 @@ class Connection(Connectable):
                 dialect=dialect,
                 column_keys=keys,
                 inline=len(distilled_params) > 1,
-                schema_translate_map=self.schema_for_object
-                if not self.schema_for_object.is_default
-                else None,
+                schema_translate_map=self._schema_translate_map,
                 linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
             )
 
@@ -1974,21 +1970,7 @@ class Engine(Connectable, log.Identified):
     _has_events = False
     _connection_cls = Connection
 
-    schema_for_object = schema._schema_getter(None)
-    """Return the ".schema" attribute for an object.
-
-    Used for :class:`.Table`, :class:`.Sequence` and similar objects,
-    and takes into account
-    the :paramref:`.Connection.execution_options.schema_translate_map`
-    parameter.
-
-      .. versionadded:: 1.1
-
-      .. seealso::
-
-          :ref:`schema_translating`
-
-    """
+    _schema_translate_map = None
 
     def __init__(
         self,
index b151b6e4830191c455b5d65c3f7c69d69821d355..d0940decf0c126968cbc05e5069aad71839b0df0 100644 (file)
@@ -28,7 +28,6 @@ from .. import types as sqltypes
 from .. import util
 from ..sql import compiler
 from ..sql import expression
-from ..sql import schema
 from ..sql.elements import quoted_name
 
 AUTOCOMMIT_REGEXP = re.compile(
@@ -129,6 +128,8 @@ class DefaultDialect(interfaces.Dialect):
 
     server_version_info = None
 
+    default_schema_name = None
+
     construct_arguments = None
     """Optional set of argument specifiers for various SQLAlchemy
     constructs, typically schema items.
@@ -495,20 +496,18 @@ class DefaultDialect(interfaces.Dialect):
                     self._set_connection_isolation(connection, isolation_level)
 
         if "schema_translate_map" in opts:
-            getter = schema._schema_getter(opts["schema_translate_map"])
-            engine.schema_for_object = getter
+            engine._schema_translate_map = map_ = opts["schema_translate_map"]
 
             @event.listens_for(engine, "engine_connect")
             def set_schema_translate_map(connection, branch):
-                connection.schema_for_object = getter
+                connection._schema_translate_map = map_
 
     def set_connection_execution_options(self, connection, opts):
         if "isolation_level" in opts:
             self._set_connection_isolation(connection, opts["isolation_level"])
 
         if "schema_translate_map" in opts:
-            getter = schema._schema_getter(opts["schema_translate_map"])
-            connection.schema_for_object = getter
+            connection._schema_translate_map = opts["schema_translate_map"]
 
     def _set_connection_isolation(self, connection, level):
         if connection.in_transaction():
@@ -701,11 +700,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             self.execution_options = dict(self.execution_options)
             self.execution_options.update(connection._execution_options)
 
+        self.unicode_statement = util.text_type(compiled)
+        if compiled.schema_translate_map:
+            rst = compiled.preparer._render_schema_translates
+            self.unicode_statement = rst(
+                self.unicode_statement, connection._schema_translate_map
+            )
+
         if not dialect.supports_unicode_statements:
-            self.unicode_statement = util.text_type(compiled)
             self.statement = dialect._encoder(self.unicode_statement)[0]
         else:
-            self.statement = self.unicode_statement = util.text_type(compiled)
+            self.statement = self.unicode_statement
 
         self.cursor = self.create_cursor()
         self.compiled_parameters = []
@@ -807,6 +812,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         elif compiled.positional:
             positiontup = self.compiled.positiontup
 
+        if compiled.schema_translate_map:
+            rst = compiled.preparer._render_schema_translates
+            self.unicode_statement = rst(
+                self.unicode_statement, connection._schema_translate_map
+            )
+
         # final self.unicode_statement is now assigned, encode if needed
         # by dialect
         if not dialect.supports_unicode_statements:
index 570ee2d0430bbfd3a4cea2e16f810e9077edd413..bda9e91b5da44aa9eda550e04443390da3b4b91a 100644 (file)
@@ -11,7 +11,6 @@ from . import base
 from . import url as _url
 from .. import util
 from ..sql import ddl
-from ..sql import schema
 
 
 class MockConnection(base.Connectable):
@@ -23,7 +22,8 @@ class MockConnection(base.Connectable):
     dialect = property(attrgetter("_dialect"))
     name = property(lambda s: s._dialect.name)
 
-    schema_for_object = schema._schema_getter(None)
+    def schema_for_object(self, obj):
+        return obj.schema
 
     def connect(self, **kwargs):
         return self
index 203369ed88cac31b90b6cea979de323d91b1a795..8ef0d572f4a74be73dfa0c5759165d7e07b03afe 100644 (file)
@@ -701,7 +701,8 @@ class Inspector(object):
 
         dialect = self.bind.dialect
 
-        schema = self.bind.schema_for_object(table)
+        with self._operation_context() as conn:
+            schema = conn.schema_for_object(table)
 
         table_name = table.name
 
index 1f183b5c10f12ed4c43027ccb6e982eff2fd755c..ae9c3c73a410152866d36af9d8e1a37352d8a658 100644 (file)
@@ -26,6 +26,7 @@ To generate user-defined SQL strings, see
 import collections
 import contextlib
 import itertools
+import operator
 import re
 
 from . import base
@@ -39,6 +40,7 @@ from . import schema
 from . import selectable
 from . import sqltypes
 from .base import NO_ARG
+from .elements import quoted_name
 from .. import exc
 from .. import util
 
@@ -369,6 +371,8 @@ class Compiled(object):
 
     _cached_metadata = None
 
+    schema_translate_map = None
+
     execution_options = util.immutabledict()
     """
     Execution options propagated from the statement.   In some cases,
@@ -381,6 +385,7 @@ class Compiled(object):
         statement,
         bind=None,
         schema_translate_map=None,
+        render_schema_translate=False,
         compile_kwargs=util.immutabledict(),
     ):
         """Construct a new :class:`.Compiled` object.
@@ -411,6 +416,7 @@ class Compiled(object):
         self.bind = bind
         self.preparer = self.dialect.identifier_preparer
         if schema_translate_map:
+            self.schema_translate_map = schema_translate_map
             self.preparer = self.preparer._with_schema_translate(
                 schema_translate_map
             )
@@ -422,6 +428,11 @@ class Compiled(object):
                 self.execution_options = statement._execution_options
             self.string = self.process(self.statement, **compile_kwargs)
 
+            if render_schema_translate:
+                self.string = self.preparer._render_schema_translates(
+                    self.string, schema_translate_map
+                )
+
     @util.deprecated(
         "0.7",
         "The :meth:`.Compiled.compile` method is deprecated and will be "
@@ -3365,18 +3376,18 @@ class DDLCompiler(Compiled):
 
         return self.sql_compiler.post_process_text(ddl.statement % context)
 
-    def visit_create_schema(self, create):
+    def visit_create_schema(self, create, **kw):
         schema = self.preparer.format_schema(create.element)
         return "CREATE SCHEMA " + schema
 
-    def visit_drop_schema(self, drop):
+    def visit_drop_schema(self, drop, **kw):
         schema = self.preparer.format_schema(drop.element)
         text = "DROP SCHEMA " + schema
         if drop.cascade:
             text += " CASCADE"
         return text
 
-    def visit_create_table(self, create):
+    def visit_create_table(self, create, **kw):
         table = create.element
         preparer = self.preparer
 
@@ -3426,7 +3437,7 @@ class DDLCompiler(Compiled):
         text += "\n)%s\n\n" % self.post_create_table(table)
         return text
 
-    def visit_create_column(self, create, first_pk=False):
+    def visit_create_column(self, create, first_pk=False, **kw):
         column = create.element
 
         if column.system:
@@ -3442,7 +3453,7 @@ class DDLCompiler(Compiled):
         return text
 
     def create_table_constraints(
-        self, table, _include_foreign_key_constraints=None
+        self, table, _include_foreign_key_constraints=None, **kw
     ):
 
         # On some DB order is significant: visit PK first, then the
@@ -3482,10 +3493,10 @@ class DDLCompiler(Compiled):
             if p is not None
         )
 
-    def visit_drop_table(self, drop):
+    def visit_drop_table(self, drop, **kw):
         return "\nDROP TABLE " + self.preparer.format_table(drop.element)
 
-    def visit_drop_view(self, drop):
+    def visit_drop_view(self, drop, **kw):
         return "\nDROP VIEW " + self.preparer.format_table(drop.element)
 
     def _verify_index_table(self, index):
@@ -3495,7 +3506,7 @@ class DDLCompiler(Compiled):
             )
 
     def visit_create_index(
-        self, create, include_schema=False, include_table_schema=True
+        self, create, include_schema=False, include_table_schema=True, **kw
     ):
         index = create.element
         self._verify_index_table(index)
@@ -3521,7 +3532,7 @@ class DDLCompiler(Compiled):
         )
         return text
 
-    def visit_drop_index(self, drop):
+    def visit_drop_index(self, drop, **kw):
         index = drop.element
 
         if index.name is None:
@@ -3548,13 +3559,13 @@ class DDLCompiler(Compiled):
             index_name = schema_name + "." + index_name
         return index_name
 
-    def visit_add_constraint(self, create):
+    def visit_add_constraint(self, create, **kw):
         return "ALTER TABLE %s ADD %s" % (
             self.preparer.format_table(create.element.table),
             self.process(create.element),
         )
 
-    def visit_set_table_comment(self, create):
+    def visit_set_table_comment(self, create, **kw):
         return "COMMENT ON TABLE %s IS %s" % (
             self.preparer.format_table(create.element),
             self.sql_compiler.render_literal_value(
@@ -3562,12 +3573,12 @@ class DDLCompiler(Compiled):
             ),
         )
 
-    def visit_drop_table_comment(self, drop):
+    def visit_drop_table_comment(self, drop, **kw):
         return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
             drop.element
         )
 
-    def visit_set_column_comment(self, create):
+    def visit_set_column_comment(self, create, **kw):
         return "COMMENT ON COLUMN %s IS %s" % (
             self.preparer.format_column(
                 create.element, use_table=True, use_schema=True
@@ -3577,12 +3588,12 @@ class DDLCompiler(Compiled):
             ),
         )
 
-    def visit_drop_column_comment(self, drop):
+    def visit_drop_column_comment(self, drop, **kw):
         return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
             drop.element, use_table=True
         )
 
-    def visit_create_sequence(self, create):
+    def visit_create_sequence(self, create, **kw):
         text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
             create.element
         )
@@ -3606,10 +3617,10 @@ class DDLCompiler(Compiled):
             text += " CYCLE"
         return text
 
-    def visit_drop_sequence(self, drop):
+    def visit_drop_sequence(self, drop, **kw):
         return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
 
-    def visit_drop_constraint(self, drop):
+    def visit_drop_constraint(self, drop, **kw):
         constraint = drop.element
         if constraint.name is not None:
             formatted_name = self.preparer.format_constraint(constraint)
@@ -3671,7 +3682,7 @@ class DDLCompiler(Compiled):
         else:
             return self.visit_check_constraint(constraint)
 
-    def visit_check_constraint(self, constraint):
+    def visit_check_constraint(self, constraint, **kw):
         text = ""
         if constraint.name is not None:
             formatted_name = self.preparer.format_constraint(constraint)
@@ -3683,7 +3694,7 @@ class DDLCompiler(Compiled):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def visit_column_check_constraint(self, constraint):
+    def visit_column_check_constraint(self, constraint, **kw):
         text = ""
         if constraint.name is not None:
             formatted_name = self.preparer.format_constraint(constraint)
@@ -3695,7 +3706,7 @@ class DDLCompiler(Compiled):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def visit_primary_key_constraint(self, constraint):
+    def visit_primary_key_constraint(self, constraint, **kw):
         if len(constraint) == 0:
             return ""
         text = ""
@@ -3715,7 +3726,7 @@ class DDLCompiler(Compiled):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def visit_foreign_key_constraint(self, constraint):
+    def visit_foreign_key_constraint(self, constraint, **kw):
         preparer = self.preparer
         text = ""
         if constraint.name is not None:
@@ -3744,7 +3755,7 @@ class DDLCompiler(Compiled):
 
         return preparer.format_table(table)
 
-    def visit_unique_constraint(self, constraint):
+    def visit_unique_constraint(self, constraint, **kw):
         if len(constraint) == 0:
             return ""
         text = ""
@@ -3789,7 +3800,7 @@ class DDLCompiler(Compiled):
             text += " MATCH %s" % constraint.match
         return text
 
-    def visit_computed_column(self, generated):
+    def visit_computed_column(self, generated, **kw):
         text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
             generated.sqltext, include_table=False, literal_binds=True
         )
@@ -3975,7 +3986,16 @@ class IdentifierPreparer(object):
 
     illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
 
-    schema_for_object = schema._schema_getter(None)
+    schema_for_object = operator.attrgetter("schema")
+    """Return the .schema attribute for an object.
+
+    For the default IdentifierPreparer, the schema for an object is always
+    the value of the ".schema" attribute.   if the preparer is replaced
+    with one that has a non-empty schema_translate_map, the value of the
+    ".schema" attribute is rendered a symbol that will be converted to a
+    real schema name from the mapping post-compile.
+
+    """
 
     def __init__(
         self,
@@ -4016,9 +4036,39 @@ class IdentifierPreparer(object):
     def _with_schema_translate(self, schema_translate_map):
         prep = self.__class__.__new__(self.__class__)
         prep.__dict__.update(self.__dict__)
-        prep.schema_for_object = schema._schema_getter(schema_translate_map)
+
+        def symbol_getter(obj):
+            name = obj.schema
+            if name in schema_translate_map and obj._use_schema_map:
+                return quoted_name(
+                    "[SCHEMA_%s]" % (name or "_none"), quote=False
+                )
+            else:
+                return obj.schema
+
+        prep.schema_for_object = symbol_getter
         return prep
 
+    def _render_schema_translates(self, statement, schema_translate_map):
+        d = schema_translate_map
+        if None in d:
+            d["_none"] = d[None]
+
+        def replace(m):
+            name = m.group(2)
+            effective_schema = d[name]
+            if not effective_schema:
+                effective_schema = self.dialect.default_schema_name
+                if not effective_schema:
+                    # TODO: no coverage here
+                    raise exc.CompileError(
+                        "Dialect has no default schema name; can't "
+                        "use None as dynamic schema target."
+                    )
+            return self.quote(effective_schema)
+
+        return re.sub(r"(\[SCHEMA_([\w\d_]+)\])", replace, statement)
+
     def _escape_identifier(self, value):
         """Escape an identifier.
 
index 69f60ba246e6ef0bc46488d8528d077bf865fdda..02c14d75132bb8f58b4a8b3bb9949a807ea67ed1 100644 (file)
@@ -31,7 +31,6 @@ as components in SQL expressions.
 from __future__ import absolute_import
 
 import collections
-import operator
 
 import sqlalchemy
 from . import coercions
@@ -143,8 +142,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
         schema_item.dispatch._update(self.dispatch)
         return schema_item
 
-    def _translate_schema(self, effective_schema, map_):
-        return map_.get(effective_schema, effective_schema)
+    _use_schema_map = True
 
 
 class Table(DialectKWArgs, SchemaItem, TableClause):
@@ -4270,59 +4268,6 @@ class ThreadLocalMetaData(MetaData):
                 e.dispose()
 
 
-class _SchemaTranslateMap(object):
-    """Provide translation of schema names based on a mapping.
-
-    Also provides helpers for producing cache keys and optimized
-    access when no mapping is present.
-
-    Used by the :paramref:`.Connection.execution_options.schema_translate_map`
-    feature.
-
-    .. versionadded:: 1.1
-
-
-    """
-
-    __slots__ = "map_", "__call__", "hash_key", "is_default"
-
-    _default_schema_getter = operator.attrgetter("schema")
-
-    def __init__(self, map_):
-        self.map_ = map_
-        if map_ is not None:
-
-            def schema_for_object(obj):
-                effective_schema = self._default_schema_getter(obj)
-                effective_schema = obj._translate_schema(
-                    effective_schema, map_
-                )
-                return effective_schema
-
-            self.__call__ = schema_for_object
-            self.hash_key = ";".join(
-                "%s=%s" % (k, map_[k]) for k in sorted(map_, key=str)
-            )
-            self.is_default = False
-        else:
-            self.hash_key = 0
-            self.__call__ = self._default_schema_getter
-            self.is_default = True
-
-    @classmethod
-    def _schema_getter(cls, map_):
-        if map_ is None:
-            return _default_schema_map
-        elif isinstance(map_, _SchemaTranslateMap):
-            return map_
-        else:
-            return _SchemaTranslateMap(map_)
-
-
-_default_schema_map = _SchemaTranslateMap(None)
-_schema_getter = _SchemaTranslateMap._schema_getter
-
-
 class Computed(FetchedValue, SchemaItem):
     """Defines a generated column, i.e. "GENERATED ALWAYS AS" syntax.
 
index 45b9e7f9dc940dec3269959ac354e546598bb508..ab13b21c46f3ff14183452ccb7fcc7cb09f00c2d 100644 (file)
@@ -346,8 +346,7 @@ class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable):
     _is_from_clause = True
     _is_join = False
 
-    def _translate_schema(self, effective_schema, map_):
-        return effective_schema
+    _use_schema_map = False
 
     _memoized_property = util.group_expirable_memoized_property(["_columns"])
 
index 3d69d1177258ca94198cd9bd57152408dad2f005..e106684bc123ebbc43c6141cceb76bf10d5630b3 100644 (file)
@@ -1000,6 +1000,8 @@ class SchemaType(SchemaEventTarget):
 
     """
 
+    _use_schema_map = True
+
     def __init__(
         self,
         name=None,
@@ -1030,9 +1032,6 @@ class SchemaType(SchemaEventTarget):
                 util.portable_instancemethod(self._on_metadata_drop),
             )
 
-    def _translate_schema(self, effective_schema, map_):
-        return map_.get(effective_schema, effective_schema)
-
     def _set_parent(self, column):
         column._on_table_attach(util.portable_instancemethod(self._set_table))
 
index e0bf4326e140d5dc355e21dadc060e465e495218..7dada1394b55e6c437b902c23de9dafac716a7e6 100644 (file)
@@ -352,6 +352,8 @@ class AssertsCompiledSQL(object):
         literal_binds=False,
         render_postcompile=False,
         schema_translate_map=None,
+        render_schema_translate=False,
+        default_schema_name=None,
         inline_flag=None,
     ):
         if use_default_dialect:
@@ -371,6 +373,9 @@ class AssertsCompiledSQL(object):
             elif isinstance(dialect, util.string_types):
                 dialect = url.URL(dialect).get_dialect()()
 
+        if default_schema_name:
+            dialect.default_schema_name = default_schema_name
+
         kw = {}
         compile_kwargs = {}
 
@@ -386,6 +391,9 @@ class AssertsCompiledSQL(object):
         if render_postcompile:
             compile_kwargs["render_postcompile"] = True
 
+        if render_schema_translate:
+            kw["render_schema_translate"] = True
+
         from sqlalchemy import orm
 
         if isinstance(clause, orm.Query):
index e38c7ddd89b8ee5cd6aeb142632a25ae405c6641..f0da694007e42c3ce11f9d68812b0a4fd6d85f5f 100644 (file)
@@ -91,21 +91,23 @@ class CompiledSQL(SQLMatchRule):
 
         context = execute_observed.context
         compare_dialect = self._compile_dialect(execute_observed)
+
+        if "schema_translate_map" in context.execution_options:
+            map_ = context.execution_options["schema_translate_map"]
+        else:
+            map_ = None
+
         if isinstance(context.compiled.statement, _DDLCompiles):
+
             compiled = context.compiled.statement.compile(
-                dialect=compare_dialect,
-                schema_translate_map=context.execution_options.get(
-                    "schema_translate_map"
-                ),
+                dialect=compare_dialect, schema_translate_map=map_
             )
         else:
             compiled = context.compiled.statement.compile(
                 dialect=compare_dialect,
                 column_keys=context.compiled.column_keys,
                 inline=context.compiled.inline,
-                schema_translate_map=context.execution_options.get(
-                    "schema_translate_map"
-                ),
+                schema_translate_map=map_,
             )
         _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
         parameters = execute_observed.parameters
index 473c981160e242691357c5a4920a05d2fec927dc..68a43feb7ff78752fe19a48af1134aab3560fd4b 100644 (file)
@@ -360,7 +360,6 @@ class ComponentReflectionTest(fixtures.TablesTest):
     @testing.requires.schema_reflection
     def test_dialect_initialize(self):
         engine = engines.testing_engine()
-        assert not hasattr(engine.dialect, "default_schema_name")
         inspect(engine)
         assert hasattr(engine.dialect, "default_schema_name")
 
index aabbc3ac3b0aebbeac79b85e4b1201e1702c410f..316f0c240bc72f0ad9217ceb0ca652b162372bba 100644 (file)
@@ -229,12 +229,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             postgresql.CreateEnumType(e1),
             "CREATE TYPE foo.somename AS ENUM ('x', 'y', 'z')",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
         self.assert_compile(
             postgresql.CreateEnumType(e2),
             "CREATE TYPE bar.somename AS ENUM ('x', 'y', 'z')",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
     def test_create_table_with_schema_type_schema_translate(self):
@@ -251,6 +253,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             CreateTable(table),
             "CREATE TABLE foo.some_table (q foo.somename, p bar.somename)",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
     def test_create_table_with_tablespace(self):
index 0b5b1b16d6eafb1d08ef80b1ab1ce8e9df0fd0d1..566cf065410786c5852c5462b2bdc26dbd6f3a97 100644 (file)
@@ -949,6 +949,13 @@ class CompiledCacheTest(fixtures.TestBase):
             conn.execute(ins, {"q": 2})
             eq_(conn.scalar(stmt), 2)
 
+        with config.db.connect().execution_options(
+            compiled_cache=cache, schema_translate_map={None: None},
+        ) as conn:
+            # should use default schema again even though statement
+            # was compiled with test_schema in the map
+            eq_(conn.scalar(stmt), 1)
+
         with config.db.connect().execution_options(
             compiled_cache=cache
         ) as conn:
@@ -1014,12 +1021,12 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
                 t1.drop(conn)
 
         asserter.assert_(
-            CompiledSQL("CREATE TABLE %s.t1 (x INTEGER)" % config.test_schema),
-            CompiledSQL("CREATE TABLE %s.t2 (x INTEGER)" % config.test_schema),
-            CompiledSQL("CREATE TABLE t3 (x INTEGER)"),
-            CompiledSQL("DROP TABLE t3"),
-            CompiledSQL("DROP TABLE %s.t2" % config.test_schema),
-            CompiledSQL("DROP TABLE %s.t1" % config.test_schema),
+            CompiledSQL("CREATE TABLE [SCHEMA__none].t1 (x INTEGER)"),
+            CompiledSQL("CREATE TABLE [SCHEMA_foo].t2 (x INTEGER)"),
+            CompiledSQL("CREATE TABLE [SCHEMA_bar].t3 (x INTEGER)"),
+            CompiledSQL("DROP TABLE [SCHEMA_bar].t3"),
+            CompiledSQL("DROP TABLE [SCHEMA_foo].t2"),
+            CompiledSQL("DROP TABLE [SCHEMA__none].t1"),
         )
 
     def _fixture(self):
@@ -1099,34 +1106,27 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
                 conn.execute(t3.delete())
 
         asserter.assert_(
+            CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
+            CompiledSQL("INSERT INTO [SCHEMA_foo].t2 (x) VALUES (:x)"),
+            CompiledSQL("INSERT INTO [SCHEMA_bar].t3 (x) VALUES (:x)"),
             CompiledSQL(
-                "INSERT INTO %s.t1 (x) VALUES (:x)" % config.test_schema
-            ),
-            CompiledSQL(
-                "INSERT INTO %s.t2 (x) VALUES (:x)" % config.test_schema
+                "UPDATE [SCHEMA__none].t1 SET x=:x WHERE "
+                "[SCHEMA__none].t1.x = :x_1"
             ),
-            CompiledSQL("INSERT INTO t3 (x) VALUES (:x)"),
             CompiledSQL(
-                "UPDATE %s.t1 SET x=:x WHERE %s.t1.x = :x_1"
-                % (config.test_schema, config.test_schema)
+                "UPDATE [SCHEMA_foo].t2 SET x=:x WHERE "
+                "[SCHEMA_foo].t2.x = :x_1"
             ),
             CompiledSQL(
-                "UPDATE %s.t2 SET x=:x WHERE %s.t2.x = :x_1"
-                % (config.test_schema, config.test_schema)
+                "UPDATE [SCHEMA_bar].t3 SET x=:x WHERE "
+                "[SCHEMA_bar].t3.x = :x_1"
             ),
-            CompiledSQL("UPDATE t3 SET x=:x WHERE t3.x = :x_1"),
-            CompiledSQL(
-                "SELECT %s.t1.x FROM %s.t1"
-                % (config.test_schema, config.test_schema)
-            ),
-            CompiledSQL(
-                "SELECT %s.t2.x FROM %s.t2"
-                % (config.test_schema, config.test_schema)
-            ),
-            CompiledSQL("SELECT t3.x FROM t3"),
-            CompiledSQL("DELETE FROM %s.t1" % config.test_schema),
-            CompiledSQL("DELETE FROM %s.t2" % config.test_schema),
-            CompiledSQL("DELETE FROM t3"),
+            CompiledSQL("SELECT [SCHEMA__none].t1.x FROM [SCHEMA__none].t1"),
+            CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2"),
+            CompiledSQL("SELECT [SCHEMA_bar].t3.x FROM [SCHEMA_bar].t3"),
+            CompiledSQL("DELETE FROM [SCHEMA__none].t1"),
+            CompiledSQL("DELETE FROM [SCHEMA_foo].t2"),
+            CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
         )
 
     @testing.provide_metadata
@@ -1147,10 +1147,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
             conn = eng.connect()
             conn.execute(select([t2.c.x]))
         asserter.assert_(
-            CompiledSQL(
-                "SELECT %s.t2.x FROM %s.t2"
-                % (config.test_schema, config.test_schema)
-            )
+            CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2")
         )
 
 
index 033da10a3aed63089a6debddcb211abe9a27db20..ef3e5d26e8597378806284b5a5e700ef27335ced 100644 (file)
@@ -3958,23 +3958,42 @@ class DDLTest(fixtures.TestBase, AssertsCompiledSQL):
 
         schema_translate_map = {None: "z", "bar": None, "foo": "bat"}
 
+        self.assert_compile(
+            schema.CreateTable(t1),
+            "CREATE TABLE [SCHEMA__none].t1 (q INTEGER)",
+            schema_translate_map=schema_translate_map,
+        )
         self.assert_compile(
             schema.CreateTable(t1),
             "CREATE TABLE z.t1 (q INTEGER)",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
+        self.assert_compile(
+            schema.CreateTable(t2),
+            "CREATE TABLE [SCHEMA_foo].t2 (q INTEGER)",
+            schema_translate_map=schema_translate_map,
+        )
         self.assert_compile(
             schema.CreateTable(t2),
             "CREATE TABLE bat.t2 (q INTEGER)",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
         self.assert_compile(
             schema.CreateTable(t3),
-            "CREATE TABLE t3 (q INTEGER)",
+            "CREATE TABLE [SCHEMA_bar].t3 (q INTEGER)",
             schema_translate_map=schema_translate_map,
         )
+        self.assert_compile(
+            schema.CreateTable(t3),
+            "CREATE TABLE main.t3 (q INTEGER)",
+            schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
+            default_schema_name="main",
+        )
 
     def test_schema_translate_map_sequence(self):
         s1 = schema.Sequence("s1")
@@ -3985,19 +4004,19 @@ class DDLTest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_compile(
             schema.CreateSequence(s1),
-            "CREATE SEQUENCE z.s1",
+            "CREATE SEQUENCE [SCHEMA__none].s1",
             schema_translate_map=schema_translate_map,
         )
 
         self.assert_compile(
             schema.CreateSequence(s2),
-            "CREATE SEQUENCE bat.s2",
+            "CREATE SEQUENCE [SCHEMA_foo].s2",
             schema_translate_map=schema_translate_map,
         )
 
         self.assert_compile(
             schema.CreateSequence(s3),
-            "CREATE SEQUENCE s3",
+            "CREATE SEQUENCE [SCHEMA_bar].s3",
             schema_translate_map=schema_translate_map,
         )
 
@@ -4135,6 +4154,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
             "bar.mytable.description FROM bar.mytable "
             "WHERE bar.mytable.name = :name_1",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
         self.assert_compile(
@@ -4143,6 +4163,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
             "foob.remotetable.value FROM foob.remotetable "
             "WHERE foob.remotetable.value = :value_1",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
         schema_translate_map = {"remote_owner": "foob"}
@@ -4155,6 +4176,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
             "foob.remotetable.value FROM mytable JOIN foob.remotetable "
             "ON mytable.myid = foob.remotetable.rem_id",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
     def test_schema_translate_aliases(self):
@@ -4183,6 +4205,18 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
             .where(alias.c.name == "foo")
         )
 
+        self.assert_compile(
+            stmt,
+            "SELECT [SCHEMA__none].myothertable.otherid, "
+            "[SCHEMA__none].myothertable.othername, "
+            "mytable_1.myid, mytable_1.name, mytable_1.description "
+            "FROM [SCHEMA__none].myothertable JOIN "
+            "[SCHEMA__none].mytable AS mytable_1 "
+            "ON [SCHEMA__none].myothertable.otherid = mytable_1.myid "
+            "WHERE mytable_1.name = :name_1",
+            schema_translate_map=schema_translate_map,
+        )
+
         self.assert_compile(
             stmt,
             "SELECT bar.myothertable.otherid, bar.myothertable.othername, "
@@ -4191,6 +4225,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
             "ON bar.myothertable.otherid = mytable_1.myid "
             "WHERE mytable_1.name = :name_1",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
     def test_schema_translate_crud(self):
@@ -4209,6 +4244,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
             table1.insert().values(description="foo"),
             "INSERT INTO bar.mytable (description) VALUES (:description)",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
         self.assert_compile(
@@ -4218,17 +4254,20 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
             "UPDATE bar.mytable SET description=:description "
             "WHERE bar.mytable.name = :name_1",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
         self.assert_compile(
             table1.delete().where(table1.c.name == "hi"),
             "DELETE FROM bar.mytable WHERE bar.mytable.name = :name_1",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
         self.assert_compile(
             table4.insert().values(value="there"),
             "INSERT INTO foob.remotetable (value) VALUES (:value)",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
         self.assert_compile(
@@ -4238,6 +4277,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
             "UPDATE foob.remotetable SET value=:value "
             "WHERE foob.remotetable.value = :value_1",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
         self.assert_compile(
@@ -4245,6 +4285,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
             "DELETE FROM foob.remotetable WHERE "
             "foob.remotetable.value = :value_1",
             schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
         )
 
     def test_alias(self):
index 13300f0b58e83add70d857a81cd8b9c14b2227c6..667891236391445671354bc561559418baacd16c 100644 (file)
@@ -28,7 +28,8 @@ class EmitDDLTest(fixtures.TestBase):
                 has_index=Mock(side_effect=has_index),
                 supports_comments=True,
                 inline_comments=False,
-            )
+            ),
+            _schema_translate_map=None,
         )
 
     def _mock_create_fixture(