From cadfc608d63f4e0df46c0daaa28902423fd88d71 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 23 Mar 2020 14:52:05 -0400 Subject: [PATCH] Convert schema_translate to a post compile Revised the :paramref:`.Connection.execution_options.schema_translate_map` feature such that the processing of the SQL statement to receive a specific schema name occurs within the execution phase of the statement, rather than at the compile phase. This is to support the statement being efficiently cached. Previously, the current schema being rendered into the statement for a particular run would be considered as part of the cache key itself, meaning that for a run against hundreds of schemas, there would be hundreds of cache keys, rendering the cache much less performant. The new behavior is that the rendering is done in a similar manner as the "post compile" rendering added in 1.4 as part of :ticket:`4645`, :ticket:`4808`. Fixes: #5004 Change-Id: Ia5c89eb27cc8dc2c5b8e76d6c07c46290a7901b6 --- doc/build/changelog/unreleased_14/5004.rst | 14 +++ lib/sqlalchemy/dialects/sqlite/base.py | 3 + lib/sqlalchemy/engine/base.py | 68 +++++------- lib/sqlalchemy/engine/default.py | 27 +++-- lib/sqlalchemy/engine/mock.py | 4 +- lib/sqlalchemy/engine/reflection.py | 3 +- lib/sqlalchemy/sql/compiler.py | 100 +++++++++++++----- lib/sqlalchemy/sql/schema.py | 57 +--------- lib/sqlalchemy/sql/selectable.py | 3 +- lib/sqlalchemy/sql/sqltypes.py | 5 +- lib/sqlalchemy/testing/assertions.py | 8 ++ lib/sqlalchemy/testing/assertsql.py | 16 +-- .../testing/suite/test_reflection.py | 1 - test/dialect/postgresql/test_compiler.py | 3 + test/engine/test_execute.py | 61 +++++------ test/sql/test_compiler.py | 49 ++++++++- test/sql/test_ddlemit.py | 3 +- 17 files changed, 240 insertions(+), 185 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/5004.rst diff --git a/doc/build/changelog/unreleased_14/5004.rst b/doc/build/changelog/unreleased_14/5004.rst new file mode 100644 index 0000000000..a13a9a7d32 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5004.rst @@ -0,0 +1,14 @@ +.. change:: + :tags: bug, engine + :tickets: 5004 + + Revised the :paramref:`.Connection.execution_options.schema_translate_map` + feature such that the processing of the SQL statement to receive a specific + schema name occurs within the execution phase of the statement, rather than + at the compile phase. This is to support the statement being efficiently + cached. Previously, the current schema being rendered into the statement + for a particular run would be considered as part of the cache key itself, + meaning that for a run against hundreds of schemas, there would be hundreds + of cache keys, rendering the cache much less performant. The new behavior + is that the rendering is done in a similar manner as the "post compile" + rendering added in 1.4 as part of :ticket:`4645`, :ticket:`4808`. diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index a63ce0033f..31425d4c0b 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1631,6 +1631,9 @@ class SQLiteDialect(default.DefaultDialect): ) return bool(info) + def _get_default_schema_name(self, connection): + return "main" + @reflection.cache def get_view_names(self, connection, schema=None, **kw): if schema is not None: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index aa21fb13bb..4ed3b9af7a 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -17,7 +17,6 @@ from .. import inspection from .. import log from .. import util from ..sql import compiler -from ..sql import schema from ..sql import util as sql_util @@ -51,21 +50,7 @@ class Connection(Connectable): """ - schema_for_object = schema._schema_getter(None) - """Return the ".schema" attribute for an object. - - Used for :class:`.Table`, :class:`.Sequence` and similar objects, - and takes into account - the :paramref:`.Connection.execution_options.schema_translate_map` - parameter. - - .. versionadded:: 1.1 - - .. seealso:: - - :ref:`schema_translating` - - """ + _schema_translate_map = None def __init__( self, @@ -92,7 +77,7 @@ class Connection(Connectable): self.should_close_with_result = False self.dispatch = _dispatch self._has_events = _branch_from._has_events - self.schema_for_object = _branch_from.schema_for_object + self._schema_translate_map = _branch_from._schema_translate_map else: self.__connection = ( connection @@ -122,6 +107,24 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: self.dispatch.engine_connect(self, self.__branch) + def schema_for_object(self, obj): + """return the schema name for the given schema item taking into + account current schema translate map. + + """ + + name = obj.schema + schema_translate_map = self._schema_translate_map + + if ( + schema_translate_map + and name in schema_translate_map + and obj._use_schema_map + ): + return schema_translate_map[name] + else: + return name + def _branch(self): """Return a new Connection which references this Connection's engine and connection; but does not have close_with_result enabled, @@ -1066,10 +1069,7 @@ class Connection(Connectable): dialect = self.dialect compiled = ddl.compile( - dialect=dialect, - schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default - else None, + dialect=dialect, schema_translate_map=self._schema_translate_map ) ret = self._execute_context( dialect, @@ -1103,7 +1103,7 @@ class Connection(Connectable): dialect, elem, tuple(sorted(keys)), - self.schema_for_object.hash_key, + bool(self._schema_translate_map), len(distilled_params) > 1, ) compiled_sql = self._execution_options["compiled_cache"].get(key) @@ -1112,9 +1112,7 @@ class Connection(Connectable): dialect=dialect, column_keys=keys, inline=len(distilled_params) > 1, - schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default - else None, + schema_translate_map=self._schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, ) @@ -1124,9 +1122,7 @@ class Connection(Connectable): dialect=dialect, column_keys=keys, inline=len(distilled_params) > 1, - schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default - else None, + schema_translate_map=self._schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, ) @@ -1974,21 +1970,7 @@ class Engine(Connectable, log.Identified): _has_events = False _connection_cls = Connection - schema_for_object = schema._schema_getter(None) - """Return the ".schema" attribute for an object. - - Used for :class:`.Table`, :class:`.Sequence` and similar objects, - and takes into account - the :paramref:`.Connection.execution_options.schema_translate_map` - parameter. - - .. versionadded:: 1.1 - - .. seealso:: - - :ref:`schema_translating` - - """ + _schema_translate_map = None def __init__( self, diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index b151b6e483..d0940decf0 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -28,7 +28,6 @@ from .. import types as sqltypes from .. import util from ..sql import compiler from ..sql import expression -from ..sql import schema from ..sql.elements import quoted_name AUTOCOMMIT_REGEXP = re.compile( @@ -129,6 +128,8 @@ class DefaultDialect(interfaces.Dialect): server_version_info = None + default_schema_name = None + construct_arguments = None """Optional set of argument specifiers for various SQLAlchemy constructs, typically schema items. @@ -495,20 +496,18 @@ class DefaultDialect(interfaces.Dialect): self._set_connection_isolation(connection, isolation_level) if "schema_translate_map" in opts: - getter = schema._schema_getter(opts["schema_translate_map"]) - engine.schema_for_object = getter + engine._schema_translate_map = map_ = opts["schema_translate_map"] @event.listens_for(engine, "engine_connect") def set_schema_translate_map(connection, branch): - connection.schema_for_object = getter + connection._schema_translate_map = map_ def set_connection_execution_options(self, connection, opts): if "isolation_level" in opts: self._set_connection_isolation(connection, opts["isolation_level"]) if "schema_translate_map" in opts: - getter = schema._schema_getter(opts["schema_translate_map"]) - connection.schema_for_object = getter + connection._schema_translate_map = opts["schema_translate_map"] def _set_connection_isolation(self, connection, level): if connection.in_transaction(): @@ -701,11 +700,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.execution_options = dict(self.execution_options) self.execution_options.update(connection._execution_options) + self.unicode_statement = util.text_type(compiled) + if compiled.schema_translate_map: + rst = compiled.preparer._render_schema_translates + self.unicode_statement = rst( + self.unicode_statement, connection._schema_translate_map + ) + if not dialect.supports_unicode_statements: - self.unicode_statement = util.text_type(compiled) self.statement = dialect._encoder(self.unicode_statement)[0] else: - self.statement = self.unicode_statement = util.text_type(compiled) + self.statement = self.unicode_statement self.cursor = self.create_cursor() self.compiled_parameters = [] @@ -807,6 +812,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): elif compiled.positional: positiontup = self.compiled.positiontup + if compiled.schema_translate_map: + rst = compiled.preparer._render_schema_translates + self.unicode_statement = rst( + self.unicode_statement, connection._schema_translate_map + ) + # final self.unicode_statement is now assigned, encode if needed # by dialect if not dialect.supports_unicode_statements: diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index 570ee2d043..bda9e91b5d 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -11,7 +11,6 @@ from . import base from . import url as _url from .. import util from ..sql import ddl -from ..sql import schema class MockConnection(base.Connectable): @@ -23,7 +22,8 @@ class MockConnection(base.Connectable): dialect = property(attrgetter("_dialect")) name = property(lambda s: s._dialect.name) - schema_for_object = schema._schema_getter(None) + def schema_for_object(self, obj): + return obj.schema def connect(self, **kwargs): return self diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 203369ed88..8ef0d572f4 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -701,7 +701,8 @@ class Inspector(object): dialect = self.bind.dialect - schema = self.bind.schema_for_object(table) + with self._operation_context() as conn: + schema = conn.schema_for_object(table) table_name = table.name diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1f183b5c10..ae9c3c73a4 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -26,6 +26,7 @@ To generate user-defined SQL strings, see import collections import contextlib import itertools +import operator import re from . import base @@ -39,6 +40,7 @@ from . import schema from . import selectable from . import sqltypes from .base import NO_ARG +from .elements import quoted_name from .. import exc from .. import util @@ -369,6 +371,8 @@ class Compiled(object): _cached_metadata = None + schema_translate_map = None + execution_options = util.immutabledict() """ Execution options propagated from the statement. In some cases, @@ -381,6 +385,7 @@ class Compiled(object): statement, bind=None, schema_translate_map=None, + render_schema_translate=False, compile_kwargs=util.immutabledict(), ): """Construct a new :class:`.Compiled` object. @@ -411,6 +416,7 @@ class Compiled(object): self.bind = bind self.preparer = self.dialect.identifier_preparer if schema_translate_map: + self.schema_translate_map = schema_translate_map self.preparer = self.preparer._with_schema_translate( schema_translate_map ) @@ -422,6 +428,11 @@ class Compiled(object): self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) + if render_schema_translate: + self.string = self.preparer._render_schema_translates( + self.string, schema_translate_map + ) + @util.deprecated( "0.7", "The :meth:`.Compiled.compile` method is deprecated and will be " @@ -3365,18 +3376,18 @@ class DDLCompiler(Compiled): return self.sql_compiler.post_process_text(ddl.statement % context) - def visit_create_schema(self, create): + def visit_create_schema(self, create, **kw): schema = self.preparer.format_schema(create.element) return "CREATE SCHEMA " + schema - def visit_drop_schema(self, drop): + def visit_drop_schema(self, drop, **kw): schema = self.preparer.format_schema(drop.element) text = "DROP SCHEMA " + schema if drop.cascade: text += " CASCADE" return text - def visit_create_table(self, create): + def visit_create_table(self, create, **kw): table = create.element preparer = self.preparer @@ -3426,7 +3437,7 @@ class DDLCompiler(Compiled): text += "\n)%s\n\n" % self.post_create_table(table) return text - def visit_create_column(self, create, first_pk=False): + def visit_create_column(self, create, first_pk=False, **kw): column = create.element if column.system: @@ -3442,7 +3453,7 @@ class DDLCompiler(Compiled): return text def create_table_constraints( - self, table, _include_foreign_key_constraints=None + self, table, _include_foreign_key_constraints=None, **kw ): # On some DB order is significant: visit PK first, then the @@ -3482,10 +3493,10 @@ class DDLCompiler(Compiled): if p is not None ) - def visit_drop_table(self, drop): + def visit_drop_table(self, drop, **kw): return "\nDROP TABLE " + self.preparer.format_table(drop.element) - def visit_drop_view(self, drop): + def visit_drop_view(self, drop, **kw): return "\nDROP VIEW " + self.preparer.format_table(drop.element) def _verify_index_table(self, index): @@ -3495,7 +3506,7 @@ class DDLCompiler(Compiled): ) def visit_create_index( - self, create, include_schema=False, include_table_schema=True + self, create, include_schema=False, include_table_schema=True, **kw ): index = create.element self._verify_index_table(index) @@ -3521,7 +3532,7 @@ class DDLCompiler(Compiled): ) return text - def visit_drop_index(self, drop): + def visit_drop_index(self, drop, **kw): index = drop.element if index.name is None: @@ -3548,13 +3559,13 @@ class DDLCompiler(Compiled): index_name = schema_name + "." + index_name return index_name - def visit_add_constraint(self, create): + def visit_add_constraint(self, create, **kw): return "ALTER TABLE %s ADD %s" % ( self.preparer.format_table(create.element.table), self.process(create.element), ) - def visit_set_table_comment(self, create): + def visit_set_table_comment(self, create, **kw): return "COMMENT ON TABLE %s IS %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( @@ -3562,12 +3573,12 @@ class DDLCompiler(Compiled): ), ) - def visit_drop_table_comment(self, drop): + def visit_drop_table_comment(self, drop, **kw): return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( drop.element ) - def visit_set_column_comment(self, create): + def visit_set_column_comment(self, create, **kw): return "COMMENT ON COLUMN %s IS %s" % ( self.preparer.format_column( create.element, use_table=True, use_schema=True @@ -3577,12 +3588,12 @@ class DDLCompiler(Compiled): ), ) - def visit_drop_column_comment(self, drop): + def visit_drop_column_comment(self, drop, **kw): return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( drop.element, use_table=True ) - def visit_create_sequence(self, create): + def visit_create_sequence(self, create, **kw): text = "CREATE SEQUENCE %s" % self.preparer.format_sequence( create.element ) @@ -3606,10 +3617,10 @@ class DDLCompiler(Compiled): text += " CYCLE" return text - def visit_drop_sequence(self, drop): + def visit_drop_sequence(self, drop, **kw): return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) - def visit_drop_constraint(self, drop): + def visit_drop_constraint(self, drop, **kw): constraint = drop.element if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3671,7 +3682,7 @@ class DDLCompiler(Compiled): else: return self.visit_check_constraint(constraint) - def visit_check_constraint(self, constraint): + def visit_check_constraint(self, constraint, **kw): text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3683,7 +3694,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_column_check_constraint(self, constraint): + def visit_column_check_constraint(self, constraint, **kw): text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3695,7 +3706,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_primary_key_constraint(self, constraint): + def visit_primary_key_constraint(self, constraint, **kw): if len(constraint) == 0: return "" text = "" @@ -3715,7 +3726,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_foreign_key_constraint(self, constraint): + def visit_foreign_key_constraint(self, constraint, **kw): preparer = self.preparer text = "" if constraint.name is not None: @@ -3744,7 +3755,7 @@ class DDLCompiler(Compiled): return preparer.format_table(table) - def visit_unique_constraint(self, constraint): + def visit_unique_constraint(self, constraint, **kw): if len(constraint) == 0: return "" text = "" @@ -3789,7 +3800,7 @@ class DDLCompiler(Compiled): text += " MATCH %s" % constraint.match return text - def visit_computed_column(self, generated): + def visit_computed_column(self, generated, **kw): text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( generated.sqltext, include_table=False, literal_binds=True ) @@ -3975,7 +3986,16 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - schema_for_object = schema._schema_getter(None) + schema_for_object = operator.attrgetter("schema") + """Return the .schema attribute for an object. + + For the default IdentifierPreparer, the schema for an object is always + the value of the ".schema" attribute. if the preparer is replaced + with one that has a non-empty schema_translate_map, the value of the + ".schema" attribute is rendered a symbol that will be converted to a + real schema name from the mapping post-compile. + + """ def __init__( self, @@ -4016,9 +4036,39 @@ class IdentifierPreparer(object): def _with_schema_translate(self, schema_translate_map): prep = self.__class__.__new__(self.__class__) prep.__dict__.update(self.__dict__) - prep.schema_for_object = schema._schema_getter(schema_translate_map) + + def symbol_getter(obj): + name = obj.schema + if name in schema_translate_map and obj._use_schema_map: + return quoted_name( + "[SCHEMA_%s]" % (name or "_none"), quote=False + ) + else: + return obj.schema + + prep.schema_for_object = symbol_getter return prep + def _render_schema_translates(self, statement, schema_translate_map): + d = schema_translate_map + if None in d: + d["_none"] = d[None] + + def replace(m): + name = m.group(2) + effective_schema = d[name] + if not effective_schema: + effective_schema = self.dialect.default_schema_name + if not effective_schema: + # TODO: no coverage here + raise exc.CompileError( + "Dialect has no default schema name; can't " + "use None as dynamic schema target." + ) + return self.quote(effective_schema) + + return re.sub(r"(\[SCHEMA_([\w\d_]+)\])", replace, statement) + def _escape_identifier(self, value): """Escape an identifier. diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 69f60ba246..02c14d7513 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -31,7 +31,6 @@ as components in SQL expressions. from __future__ import absolute_import import collections -import operator import sqlalchemy from . import coercions @@ -143,8 +142,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): schema_item.dispatch._update(self.dispatch) return schema_item - def _translate_schema(self, effective_schema, map_): - return map_.get(effective_schema, effective_schema) + _use_schema_map = True class Table(DialectKWArgs, SchemaItem, TableClause): @@ -4270,59 +4268,6 @@ class ThreadLocalMetaData(MetaData): e.dispose() -class _SchemaTranslateMap(object): - """Provide translation of schema names based on a mapping. - - Also provides helpers for producing cache keys and optimized - access when no mapping is present. - - Used by the :paramref:`.Connection.execution_options.schema_translate_map` - feature. - - .. versionadded:: 1.1 - - - """ - - __slots__ = "map_", "__call__", "hash_key", "is_default" - - _default_schema_getter = operator.attrgetter("schema") - - def __init__(self, map_): - self.map_ = map_ - if map_ is not None: - - def schema_for_object(obj): - effective_schema = self._default_schema_getter(obj) - effective_schema = obj._translate_schema( - effective_schema, map_ - ) - return effective_schema - - self.__call__ = schema_for_object - self.hash_key = ";".join( - "%s=%s" % (k, map_[k]) for k in sorted(map_, key=str) - ) - self.is_default = False - else: - self.hash_key = 0 - self.__call__ = self._default_schema_getter - self.is_default = True - - @classmethod - def _schema_getter(cls, map_): - if map_ is None: - return _default_schema_map - elif isinstance(map_, _SchemaTranslateMap): - return map_ - else: - return _SchemaTranslateMap(map_) - - -_default_schema_map = _SchemaTranslateMap(None) -_schema_getter = _SchemaTranslateMap._schema_getter - - class Computed(FetchedValue, SchemaItem): """Defines a generated column, i.e. "GENERATED ALWAYS AS" syntax. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 45b9e7f9dc..ab13b21c46 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -346,8 +346,7 @@ class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable): _is_from_clause = True _is_join = False - def _translate_schema(self, effective_schema, map_): - return effective_schema + _use_schema_map = False _memoized_property = util.group_expirable_memoized_property(["_columns"]) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 3d69d11772..e106684bc1 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1000,6 +1000,8 @@ class SchemaType(SchemaEventTarget): """ + _use_schema_map = True + def __init__( self, name=None, @@ -1030,9 +1032,6 @@ class SchemaType(SchemaEventTarget): util.portable_instancemethod(self._on_metadata_drop), ) - def _translate_schema(self, effective_schema, map_): - return map_.get(effective_schema, effective_schema) - def _set_parent(self, column): column._on_table_attach(util.portable_instancemethod(self._set_table)) diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index e0bf4326e1..7dada1394b 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -352,6 +352,8 @@ class AssertsCompiledSQL(object): literal_binds=False, render_postcompile=False, schema_translate_map=None, + render_schema_translate=False, + default_schema_name=None, inline_flag=None, ): if use_default_dialect: @@ -371,6 +373,9 @@ class AssertsCompiledSQL(object): elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() + if default_schema_name: + dialect.default_schema_name = default_schema_name + kw = {} compile_kwargs = {} @@ -386,6 +391,9 @@ class AssertsCompiledSQL(object): if render_postcompile: compile_kwargs["render_postcompile"] = True + if render_schema_translate: + kw["render_schema_translate"] = True + from sqlalchemy import orm if isinstance(clause, orm.Query): diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index e38c7ddd89..f0da694007 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -91,21 +91,23 @@ class CompiledSQL(SQLMatchRule): context = execute_observed.context compare_dialect = self._compile_dialect(execute_observed) + + if "schema_translate_map" in context.execution_options: + map_ = context.execution_options["schema_translate_map"] + else: + map_ = None + if isinstance(context.compiled.statement, _DDLCompiles): + compiled = context.compiled.statement.compile( - dialect=compare_dialect, - schema_translate_map=context.execution_options.get( - "schema_translate_map" - ), + dialect=compare_dialect, schema_translate_map=map_ ) else: compiled = context.compiled.statement.compile( dialect=compare_dialect, column_keys=context.compiled.column_keys, inline=context.compiled.inline, - schema_translate_map=context.execution_options.get( - "schema_translate_map" - ), + schema_translate_map=map_, ) _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled)) parameters = execute_observed.parameters diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 473c981160..68a43feb7f 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -360,7 +360,6 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_dialect_initialize(self): engine = engines.testing_engine() - assert not hasattr(engine.dialect, "default_schema_name") inspect(engine) assert hasattr(engine.dialect, "default_schema_name") diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index aabbc3ac3b..316f0c240b 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -229,12 +229,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): postgresql.CreateEnumType(e1), "CREATE TYPE foo.somename AS ENUM ('x', 'y', 'z')", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) self.assert_compile( postgresql.CreateEnumType(e2), "CREATE TYPE bar.somename AS ENUM ('x', 'y', 'z')", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) def test_create_table_with_schema_type_schema_translate(self): @@ -251,6 +253,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): CreateTable(table), "CREATE TABLE foo.some_table (q foo.somename, p bar.somename)", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) def test_create_table_with_tablespace(self): diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 0b5b1b16d6..566cf06541 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -949,6 +949,13 @@ class CompiledCacheTest(fixtures.TestBase): conn.execute(ins, {"q": 2}) eq_(conn.scalar(stmt), 2) + with config.db.connect().execution_options( + compiled_cache=cache, schema_translate_map={None: None}, + ) as conn: + # should use default schema again even though statement + # was compiled with test_schema in the map + eq_(conn.scalar(stmt), 1) + with config.db.connect().execution_options( compiled_cache=cache ) as conn: @@ -1014,12 +1021,12 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): t1.drop(conn) asserter.assert_( - CompiledSQL("CREATE TABLE %s.t1 (x INTEGER)" % config.test_schema), - CompiledSQL("CREATE TABLE %s.t2 (x INTEGER)" % config.test_schema), - CompiledSQL("CREATE TABLE t3 (x INTEGER)"), - CompiledSQL("DROP TABLE t3"), - CompiledSQL("DROP TABLE %s.t2" % config.test_schema), - CompiledSQL("DROP TABLE %s.t1" % config.test_schema), + CompiledSQL("CREATE TABLE [SCHEMA__none].t1 (x INTEGER)"), + CompiledSQL("CREATE TABLE [SCHEMA_foo].t2 (x INTEGER)"), + CompiledSQL("CREATE TABLE [SCHEMA_bar].t3 (x INTEGER)"), + CompiledSQL("DROP TABLE [SCHEMA_bar].t3"), + CompiledSQL("DROP TABLE [SCHEMA_foo].t2"), + CompiledSQL("DROP TABLE [SCHEMA__none].t1"), ) def _fixture(self): @@ -1099,34 +1106,27 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): conn.execute(t3.delete()) asserter.assert_( + CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"), + CompiledSQL("INSERT INTO [SCHEMA_foo].t2 (x) VALUES (:x)"), + CompiledSQL("INSERT INTO [SCHEMA_bar].t3 (x) VALUES (:x)"), CompiledSQL( - "INSERT INTO %s.t1 (x) VALUES (:x)" % config.test_schema - ), - CompiledSQL( - "INSERT INTO %s.t2 (x) VALUES (:x)" % config.test_schema + "UPDATE [SCHEMA__none].t1 SET x=:x WHERE " + "[SCHEMA__none].t1.x = :x_1" ), - CompiledSQL("INSERT INTO t3 (x) VALUES (:x)"), CompiledSQL( - "UPDATE %s.t1 SET x=:x WHERE %s.t1.x = :x_1" - % (config.test_schema, config.test_schema) + "UPDATE [SCHEMA_foo].t2 SET x=:x WHERE " + "[SCHEMA_foo].t2.x = :x_1" ), CompiledSQL( - "UPDATE %s.t2 SET x=:x WHERE %s.t2.x = :x_1" - % (config.test_schema, config.test_schema) + "UPDATE [SCHEMA_bar].t3 SET x=:x WHERE " + "[SCHEMA_bar].t3.x = :x_1" ), - CompiledSQL("UPDATE t3 SET x=:x WHERE t3.x = :x_1"), - CompiledSQL( - "SELECT %s.t1.x FROM %s.t1" - % (config.test_schema, config.test_schema) - ), - CompiledSQL( - "SELECT %s.t2.x FROM %s.t2" - % (config.test_schema, config.test_schema) - ), - CompiledSQL("SELECT t3.x FROM t3"), - CompiledSQL("DELETE FROM %s.t1" % config.test_schema), - CompiledSQL("DELETE FROM %s.t2" % config.test_schema), - CompiledSQL("DELETE FROM t3"), + CompiledSQL("SELECT [SCHEMA__none].t1.x FROM [SCHEMA__none].t1"), + CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2"), + CompiledSQL("SELECT [SCHEMA_bar].t3.x FROM [SCHEMA_bar].t3"), + CompiledSQL("DELETE FROM [SCHEMA__none].t1"), + CompiledSQL("DELETE FROM [SCHEMA_foo].t2"), + CompiledSQL("DELETE FROM [SCHEMA_bar].t3"), ) @testing.provide_metadata @@ -1147,10 +1147,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): conn = eng.connect() conn.execute(select([t2.c.x])) asserter.assert_( - CompiledSQL( - "SELECT %s.t2.x FROM %s.t2" - % (config.test_schema, config.test_schema) - ) + CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2") ) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 033da10a3a..ef3e5d26e8 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -3958,23 +3958,42 @@ class DDLTest(fixtures.TestBase, AssertsCompiledSQL): schema_translate_map = {None: "z", "bar": None, "foo": "bat"} + self.assert_compile( + schema.CreateTable(t1), + "CREATE TABLE [SCHEMA__none].t1 (q INTEGER)", + schema_translate_map=schema_translate_map, + ) self.assert_compile( schema.CreateTable(t1), "CREATE TABLE z.t1 (q INTEGER)", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) + self.assert_compile( + schema.CreateTable(t2), + "CREATE TABLE [SCHEMA_foo].t2 (q INTEGER)", + schema_translate_map=schema_translate_map, + ) self.assert_compile( schema.CreateTable(t2), "CREATE TABLE bat.t2 (q INTEGER)", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) self.assert_compile( schema.CreateTable(t3), - "CREATE TABLE t3 (q INTEGER)", + "CREATE TABLE [SCHEMA_bar].t3 (q INTEGER)", schema_translate_map=schema_translate_map, ) + self.assert_compile( + schema.CreateTable(t3), + "CREATE TABLE main.t3 (q INTEGER)", + schema_translate_map=schema_translate_map, + render_schema_translate=True, + default_schema_name="main", + ) def test_schema_translate_map_sequence(self): s1 = schema.Sequence("s1") @@ -3985,19 +4004,19 @@ class DDLTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( schema.CreateSequence(s1), - "CREATE SEQUENCE z.s1", + "CREATE SEQUENCE [SCHEMA__none].s1", schema_translate_map=schema_translate_map, ) self.assert_compile( schema.CreateSequence(s2), - "CREATE SEQUENCE bat.s2", + "CREATE SEQUENCE [SCHEMA_foo].s2", schema_translate_map=schema_translate_map, ) self.assert_compile( schema.CreateSequence(s3), - "CREATE SEQUENCE s3", + "CREATE SEQUENCE [SCHEMA_bar].s3", schema_translate_map=schema_translate_map, ) @@ -4135,6 +4154,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "bar.mytable.description FROM bar.mytable " "WHERE bar.mytable.name = :name_1", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) self.assert_compile( @@ -4143,6 +4163,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "foob.remotetable.value FROM foob.remotetable " "WHERE foob.remotetable.value = :value_1", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) schema_translate_map = {"remote_owner": "foob"} @@ -4155,6 +4176,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "foob.remotetable.value FROM mytable JOIN foob.remotetable " "ON mytable.myid = foob.remotetable.rem_id", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) def test_schema_translate_aliases(self): @@ -4183,6 +4205,18 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): .where(alias.c.name == "foo") ) + self.assert_compile( + stmt, + "SELECT [SCHEMA__none].myothertable.otherid, " + "[SCHEMA__none].myothertable.othername, " + "mytable_1.myid, mytable_1.name, mytable_1.description " + "FROM [SCHEMA__none].myothertable JOIN " + "[SCHEMA__none].mytable AS mytable_1 " + "ON [SCHEMA__none].myothertable.otherid = mytable_1.myid " + "WHERE mytable_1.name = :name_1", + schema_translate_map=schema_translate_map, + ) + self.assert_compile( stmt, "SELECT bar.myothertable.otherid, bar.myothertable.othername, " @@ -4191,6 +4225,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "ON bar.myothertable.otherid = mytable_1.myid " "WHERE mytable_1.name = :name_1", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) def test_schema_translate_crud(self): @@ -4209,6 +4244,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): table1.insert().values(description="foo"), "INSERT INTO bar.mytable (description) VALUES (:description)", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) self.assert_compile( @@ -4218,17 +4254,20 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "UPDATE bar.mytable SET description=:description " "WHERE bar.mytable.name = :name_1", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) self.assert_compile( table1.delete().where(table1.c.name == "hi"), "DELETE FROM bar.mytable WHERE bar.mytable.name = :name_1", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) self.assert_compile( table4.insert().values(value="there"), "INSERT INTO foob.remotetable (value) VALUES (:value)", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) self.assert_compile( @@ -4238,6 +4277,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "UPDATE foob.remotetable SET value=:value " "WHERE foob.remotetable.value = :value_1", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) self.assert_compile( @@ -4245,6 +4285,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "DELETE FROM foob.remotetable WHERE " "foob.remotetable.value = :value_1", schema_translate_map=schema_translate_map, + render_schema_translate=True, ) def test_alias(self): diff --git a/test/sql/test_ddlemit.py b/test/sql/test_ddlemit.py index 13300f0b58..6678912363 100644 --- a/test/sql/test_ddlemit.py +++ b/test/sql/test_ddlemit.py @@ -28,7 +28,8 @@ class EmitDDLTest(fixtures.TestBase): has_index=Mock(side_effect=has_index), supports_comments=True, inline_comments=False, - ) + ), + _schema_translate_map=None, ) def _mock_create_fixture( -- 2.47.3