From: Mike Bayer Date: Sat, 9 Jan 2016 03:11:09 +0000 (-0500) Subject: - Multi-tenancy schema translation for :class:`.Table` objects is added. X-Git-Tag: rel_1_1_0b1~84^2~53 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=89facbed8855d1443dbe37919ff0645aea640ed0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Multi-tenancy schema translation for :class:`.Table` objects is added. This supports the use case of an application that uses the same set of :class:`.Table` objects in many schemas, such as schema-per-user. A new execution option :paramref:`.Connection.execution_options.schema_translate_map` is added. fixes #2685 - latest tox doesn't like the {posargs} in the profile rerunner --- diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst index 0c103a0d14..975badc361 100644 --- a/doc/build/changelog/changelog_11.rst +++ b/doc/build/changelog/changelog_11.rst @@ -21,6 +21,21 @@ .. changelog:: :version: 1.1.0b1 + .. change:: + :tags: feature, engine + :tickets: 2685 + + Multi-tenancy schema translation for :class:`.Table` objects is added. + This supports the use case of an application that uses the same set of + :class:`.Table` objects in many schemas, such as schema-per-user. + A new execution option + :paramref:`.Connection.execution_options.schema_translate_map` is + added. + + .. seealso:: + + :ref:`change_2685` + .. change:: :tags: feature, engine :tickets: 3536 diff --git a/doc/build/changelog/migration_11.rst b/doc/build/changelog/migration_11.rst index 70182091ce..b87e7207b5 100644 --- a/doc/build/changelog/migration_11.rst +++ b/doc/build/changelog/migration_11.rst @@ -748,6 +748,43 @@ can be done like any other type:: :ticket:`2919` +.. _change_2685: + +Multi-Tenancy Schema Translation for Table objects +-------------------------------------------------- + +To support the use case of an application that uses the same set of +:class:`.Table` objects in many schemas, such as schema-per-user, a new +execution option :paramref:`.Connection.execution_options.schema_translate_map` +is added. Using this mapping, a set of :class:`.Table` +objects can be made on a per-connection basis to refer to any set of schemas +instead of the :paramref:`.Table.schema` to which they were assigned. The +translation works for DDL and SQL generation, as well as with the ORM. + +For example, if the ``User`` class were assigned the schema "per_user":: + + class User(Base): + __tablename__ = 'user' + id = Column(Integer, primary_key=True) + + __table_args__ = {'schema': 'per_user'} + +On each request, the :class:`.Session` can be set up to refer to a +different schema each time:: + + session = Session() + session.connection(execution_options={ + "schema_translate_map": {"per_user": "account_one"}}) + + # will query from the ``account_one.user`` table + session.query(User).get(5) + +.. seealso:: + + :ref:`schema_translating` + +:ticket:`2685` + .. _change_3531: The type_coerce function is now a persistent SQL element diff --git a/doc/build/core/connections.rst b/doc/build/core/connections.rst index a41babd29f..709642ecff 100644 --- a/doc/build/core/connections.rst +++ b/doc/build/core/connections.rst @@ -368,6 +368,69 @@ the SQL statement. When the :class:`.ResultProxy` is closed, the underlying :class:`.Connection` is closed for us, resulting in the DBAPI connection being returned to the pool with transactional resources removed. +.. _schema_translating: + +Translation of Schema Names +=========================== + +To support multi-tenancy applications that distribute common sets of tables +into multiple schemas, the +:paramref:`.Connection.execution_options.schema_translate_map` +execution option may be used to repurpose a set of :class:`.Table` objects +to render under different schema names without any changes. + +Given a table:: + + user_table = Table( + 'user', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)) + ) + +The "schema" of this :class:`.Table` as defined by the +:paramref:`.Table.schema` attribute is ``None``. The +:paramref:`.Connection.execution_options.schema_translate_map` can specify +that all :class:`.Table` objects with a schema of ``None`` would instead +render the schema as ``user_schema_one``:: + + connection = engine.connect().execution_options( + schema_translate_map={None: "user_schema_one"}) + + result = connection.execute(user_table.select()) + +The above code will invoke SQL on the database of the form:: + + SELECT user_schema_one.user.id, user_schema_one.user.name FROM + user_schema.user + +That is, the schema name is substituted with our translated name. The +map can specify any number of target->destination schemas:: + + connection = engine.connect().execution_options( + schema_translate_map={ + None: "user_schema_one", # no schema name -> "user_schema_one" + "special": "special_schema", # schema="special" becomes "special_schema" + "public": None # Table objects with schema="public" will render with no schema + }) + +The :paramref:`.Connection.execution_options.schema_translate_map` parameter +affects all DDL and SQL constructs generated from the SQL expression language, +as derived from the :class:`.Table` or :class:`.Sequence` objects. +It does **not** impact literal string SQL used via the :func:`.expression.text` +construct nor via plain strings passed to :meth:`.Connection.execute`. + +The feature takes effect **only** in those cases where the name of the +schema is derived directly from that of a :class:`.Table` or :class:`.Sequence`; +it does not impact methods where a string schema name is passed directly. +By this pattern, it takes effect within the "can create" / "can drop" checks +performed by methods such as :meth:`.MetaData.create_all` or +:meth:`.MetaData.drop_all` are called, and it takes effect when +using table reflection given a :class:`.Table` object. However it does +**not** affect the operations present on the :class:`.Inspector` object, +as the schema name is passed to these methods explicitly. + +.. versionadded:: 1.1 + .. _engine_disposal: Engine Disposal diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 3f9fcb27fc..3b3d65155c 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1481,8 +1481,11 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): raise exc.CompileError("Postgresql ENUM type requires a name.") name = self.quote(type_.name) - if not self.omit_schema and use_schema and type_.schema is not None: - name = self.quote_schema(type_.schema) + "." + name + effective_schema = self._get_effective_schema(type_) + + if not self.omit_schema and use_schema and \ + effective_schema is not None: + name = self.quote_schema(effective_schema) + "." + name return name @@ -1575,10 +1578,15 @@ class PGExecutionContext(default.DefaultExecutionContext): name = "%s_%s_seq" % (tab, col) column._postgresql_seq_name = seq_name = name - sch = column.table.schema - if sch is not None: + if column.table is not None: + effective_schema = self.connection._get_effective_schema( + column.table) + else: + effective_schema = None + + if effective_schema is not None: exc = "select nextval('\"%s\".\"%s\"')" % \ - (sch, seq_name) + (effective_schema, seq_name) else: exc = "select nextval('\"%s\"')" % \ (seq_name, ) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 31e253eedd..88f53abcf5 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -44,6 +44,8 @@ class Connection(Connectable): """ + _schema_translate_map = None + def __init__(self, engine, connection=None, close_with_result=False, _branch_from=None, _execution_options=None, _dispatch=None, @@ -140,6 +142,13 @@ class Connection(Connectable): c.__dict__ = self.__dict__.copy() return c + def _get_effective_schema(self, table): + effective_schema = table.schema + if self._schema_translate_map: + effective_schema = self._schema_translate_map.get( + effective_schema, effective_schema) + return effective_schema + def __enter__(self): return self @@ -277,6 +286,19 @@ class Connection(Connectable): of many DBAPIs. The flag is currently understood only by the psycopg2 dialect. + :param schema_translate_map: Available on: Connection, Engine. + A dictionary mapping schema names to schema names, that will be + applied to the :paramref:`.Table.schema` element of each + :class:`.Table` encountered when SQL or DDL expression elements + are compiled into strings; the resulting schema name will be + converted based on presence in the map of the original name. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`schema_translating` + """ c = self._clone() c._execution_options = c._execution_options.union(opt) @@ -959,7 +981,9 @@ class Connection(Connectable): dialect = self.dialect - compiled = ddl.compile(dialect=dialect) + compiled = ddl.compile( + dialect=dialect, + schema_translate_map=self._schema_translate_map) ret = self._execute_context( dialect, dialect.execution_ctx_cls._init_ddl, @@ -990,17 +1014,27 @@ class Connection(Connectable): dialect = self.dialect if 'compiled_cache' in self._execution_options: - key = dialect, elem, tuple(sorted(keys)), len(distilled_params) > 1 + key = ( + dialect, elem, tuple(sorted(keys)), + tuple( + (k, self._schema_translate_map[k]) + for k in sorted(self._schema_translate_map) + ) if self._schema_translate_map else None, + len(distilled_params) > 1 + ) compiled_sql = self._execution_options['compiled_cache'].get(key) if compiled_sql is None: compiled_sql = elem.compile( dialect=dialect, column_keys=keys, - inline=len(distilled_params) > 1) + inline=len(distilled_params) > 1, + schema_translate_map=self._schema_translate_map + ) self._execution_options['compiled_cache'][key] = compiled_sql else: compiled_sql = elem.compile( dialect=dialect, column_keys=keys, - inline=len(distilled_params) > 1) + inline=len(distilled_params) > 1, + schema_translate_map=self._schema_translate_map) ret = self._execute_context( dialect, diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 87278c2bef..160fe545e5 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -398,10 +398,18 @@ class DefaultDialect(interfaces.Dialect): if not branch: self._set_connection_isolation(connection, isolation_level) + if 'schema_translate_map' in opts: + @event.listens_for(engine, "engine_connect") + def set_schema_translate_map(connection, branch): + connection._schema_translate_map = opts['schema_translate_map'] + def set_connection_execution_options(self, connection, opts): if 'isolation_level' in opts: self._set_connection_isolation(connection, opts['isolation_level']) + if 'schema_translate_map' in opts: + connection._schema_translate_map = opts['schema_translate_map'] + def _set_connection_isolation(self, connection, level): if connection.in_transaction(): util.warn( diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 59eed51ecc..dca99e1ce2 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -529,7 +529,8 @@ class Inspector(object): """ dialect = self.bind.dialect - schema = table.schema + schema = self.bind._get_effective_schema(table) + table_name = table.name # get table-level arguments that are specifically diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 0d0414ed1d..cb3e6fa8a9 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -233,6 +233,9 @@ class MockEngineStrategy(EngineStrategy): dialect = property(attrgetter('_dialect')) name = property(lambda s: s._dialect.name) + def _get_effective_schema(self, table): + return table.schema + def contextual_connect(self, **kwargs): return self diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 98ab60aaa6..4068d18be9 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -167,6 +167,7 @@ class Compiled(object): _cached_metadata = None def __init__(self, dialect, statement, bind=None, + schema_translate_map=None, compile_kwargs=util.immutabledict()): """Construct a new :class:`.Compiled` object. @@ -177,15 +178,24 @@ class Compiled(object): :param bind: Optional Engine or Connection to compile this statement against. + :param schema_translate_map: dictionary of schema names to be + translated when forming the resultant SQL + + .. versionadded:: 1.1 + :param compile_kwargs: additional kwargs that will be passed to the initial call to :meth:`.Compiled.process`. - .. versionadded:: 0.8 """ self.dialect = dialect self.bind = bind + self.preparer = self.dialect.identifier_preparer + if schema_translate_map: + self.preparer = self.preparer._with_schema_translate( + schema_translate_map) + if statement is not None: self.statement = statement self.can_execute = statement.supports_execution @@ -385,8 +395,6 @@ class SQLCompiler(Compiled): self.ctes = None - # an IdentifierPreparer that formats the quoting of identifiers - self.preparer = dialect.identifier_preparer self.label_length = dialect.label_length \ or dialect.max_identifier_length @@ -653,8 +661,16 @@ class SQLCompiler(Compiled): if table is None or not include_table or not table.named_with_column: return name else: - if table.schema: - schema_prefix = self.preparer.quote_schema(table.schema) + '.' + + # inlining of preparer._get_effective_schema + effective_schema = table.schema + if self.preparer.schema_translate_map: + effective_schema = self.preparer.schema_translate_map.get( + effective_schema, effective_schema) + + if effective_schema: + schema_prefix = self.preparer.quote_schema( + effective_schema) + '.' else: schema_prefix = '' tablename = table.name @@ -1814,8 +1830,15 @@ class SQLCompiler(Compiled): def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, fromhints=None, use_schema=True, **kwargs): if asfrom or ashint: - if use_schema and getattr(table, "schema", None): - ret = self.preparer.quote_schema(table.schema) + \ + + # inlining of preparer._get_effective_schema + effective_schema = table.schema + if self.preparer.schema_translate_map: + effective_schema = self.preparer.schema_translate_map.get( + effective_schema, effective_schema) + + if use_schema and effective_schema: + ret = self.preparer.quote_schema(effective_schema) + \ "." + self.preparer.quote(table.name) else: ret = self.preparer.quote(table.name) @@ -2103,10 +2126,6 @@ class DDLCompiler(Compiled): def type_compiler(self): return self.dialect.type_compiler - @property - def preparer(self): - return self.dialect.identifier_preparer - def construct_params(self, params=None): return None @@ -2116,7 +2135,7 @@ class DDLCompiler(Compiled): if isinstance(ddl.target, schema.Table): context = context.copy() - preparer = self.dialect.identifier_preparer + preparer = self.preparer path = preparer.format_table_seq(ddl.target) if len(path) == 1: table, sch = path[0], '' @@ -2142,7 +2161,7 @@ class DDLCompiler(Compiled): def visit_create_table(self, create): table = create.element - preparer = self.dialect.identifier_preparer + preparer = self.preparer text = "\nCREATE " if table._prefixes: @@ -2269,9 +2288,12 @@ class DDLCompiler(Compiled): index, include_schema=True) def _prepared_index_name(self, index, include_schema=False): - if include_schema and index.table is not None and index.table.schema: - schema = index.table.schema - schema_name = self.preparer.quote_schema(schema) + if index.table is not None: + effective_schema = self.preparer._get_effective_schema(index.table) + else: + effective_schema = None + if include_schema and effective_schema: + schema_name = self.preparer.quote_schema(effective_schema) else: schema_name = None @@ -2399,7 +2421,7 @@ class DDLCompiler(Compiled): return text def visit_foreign_key_constraint(self, constraint): - preparer = self.dialect.identifier_preparer + preparer = self.preparer text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -2626,6 +2648,8 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + schema_translate_map = util.immutabledict() + def __init__(self, dialect, initial_quote='"', final_quote=None, escape_quote='"', omit_schema=False): """Construct a new ``IdentifierPreparer`` object. @@ -2650,6 +2674,12 @@ class IdentifierPreparer(object): self.omit_schema = omit_schema self._strings = {} + def _with_schema_translate(self, schema_translate_map): + prep = self.__class__.__new__(self.__class__) + prep.__dict__.update(self.__dict__) + prep.schema_translate_map = schema_translate_map + return prep + def _escape_identifier(self, value): """Escape an identifier. @@ -2722,9 +2752,12 @@ class IdentifierPreparer(object): def format_sequence(self, sequence, use_schema=True): name = self.quote(sequence.name) + + effective_schema = self._get_effective_schema(sequence) + if (not self.omit_schema and use_schema and - sequence.schema is not None): - name = self.quote_schema(sequence.schema) + "." + name + effective_schema is not None): + name = self.quote_schema(effective_schema) + "." + name return name def format_label(self, label, name=None): @@ -2747,15 +2780,25 @@ class IdentifierPreparer(object): return None return self.quote(constraint.name) + def _get_effective_schema(self, table): + effective_schema = table.schema + if self.schema_translate_map: + effective_schema = self.schema_translate_map.get( + effective_schema, effective_schema) + return effective_schema + def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" if name is None: name = table.name result = self.quote(name) + + effective_schema = self._get_effective_schema(table) + if not self.omit_schema and use_schema \ - and getattr(table, "schema", None): - result = self.quote_schema(table.schema) + "." + result + and effective_schema: + result = self.quote_schema(effective_schema) + "." + result return result def format_schema(self, name, quote=None): @@ -2794,9 +2837,11 @@ class IdentifierPreparer(object): # ('database', 'owner', etc.) could override this and return # a longer sequence. + effective_schema = self._get_effective_schema(table) + if not self.omit_schema and use_schema and \ - getattr(table, 'schema', None): - return (self.quote_schema(table.schema), + effective_schema: + return (self.quote_schema(effective_schema), self.format_table(table, use_schema=False)) else: return (self.format_table(table, use_schema=False), ) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 71018f132b..7225da5518 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -679,13 +679,16 @@ class SchemaGenerator(DDLBase): def _can_create_table(self, table): self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) + effective_schema = self.connection._get_effective_schema(table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) return not self.checkfirst or \ not self.dialect.has_table(self.connection, - table.name, schema=table.schema) + table.name, schema=effective_schema) def _can_create_sequence(self, sequence): + effective_schema = self.connection._get_effective_schema(sequence) + return self.dialect.supports_sequences and \ ( (not self.dialect.sequences_optional or @@ -695,7 +698,7 @@ class SchemaGenerator(DDLBase): not self.dialect.has_sequence( self.connection, sequence.name, - schema=sequence.schema) + schema=effective_schema) ) ) @@ -882,12 +885,14 @@ class SchemaDropper(DDLBase): def _can_drop_table(self, table): self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) + effective_schema = self.connection._get_effective_schema(table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) return not self.checkfirst or self.dialect.has_table( - self.connection, table.name, schema=table.schema) + self.connection, table.name, schema=effective_schema) def _can_drop_sequence(self, sequence): + effective_schema = self.connection._get_effective_schema(sequence) return self.dialect.supports_sequences and \ ((not self.dialect.sequences_optional or not sequence.optional) and @@ -895,7 +900,7 @@ class SchemaDropper(DDLBase): self.dialect.has_sequence( self.connection, sequence.name, - schema=sequence.schema)) + schema=effective_schema)) ) def visit_index(self, index): diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 63667654d4..ad0aa43627 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -273,7 +273,8 @@ class AssertsCompiledSQL(object): check_prefetch=None, use_default_dialect=False, allow_dialect_select=False, - literal_binds=False): + literal_binds=False, + schema_translate_map=None): if use_default_dialect: dialect = default.DefaultDialect() elif allow_dialect_select: @@ -292,6 +293,9 @@ class AssertsCompiledSQL(object): kw = {} compile_kwargs = {} + if schema_translate_map: + kw['schema_translate_map'] = schema_translate_map + if params is not None: kw['column_keys'] = list(params) diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 39d0789855..904149c164 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -87,13 +87,18 @@ class CompiledSQL(SQLMatchRule): compare_dialect = self._compile_dialect(execute_observed) if isinstance(context.compiled.statement, _DDLCompiles): compiled = \ - context.compiled.statement.compile(dialect=compare_dialect) + context.compiled.statement.compile( + dialect=compare_dialect, + schema_translate_map=context. + compiled.preparer.schema_translate_map) else: compiled = ( context.compiled.statement.compile( dialect=compare_dialect, column_keys=context.compiled.column_keys, - inline=context.compiled.inline) + inline=context.compiled.inline, + schema_translate_map=context. + compiled.preparer.schema_translate_map) ) _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled)) parameters = execute_observed.parameters diff --git a/regen_callcounts.tox.ini b/regen_callcounts.tox.ini index e74ceef362..619d46c491 100644 --- a/regen_callcounts.tox.ini +++ b/regen_callcounts.tox.ini @@ -15,10 +15,10 @@ deps=pytest commands= - py{27}-sqla_{cext,nocext}-db_{mysql}: {[base]basecommand} --db mysql {posargs} - py{33,34}-sqla_{cext,nocext}-db_{mysql}: {[base]basecommand} --db pymysql {posargs} - db_{postgresql}: {[base]basecommand} --db postgresql {posargs} - db_{sqlite}: {[base]basecommand} --db sqlite {posargs} + py{27}-sqla_{cext,nocext}-db_{mysql}: {[base]basecommand} --db mysql + py{33,34}-sqla_{cext,nocext}-db_{mysql}: {[base]basecommand} --db pymysql + db_{postgresql}: {[base]basecommand} --db postgresql + db_{sqlite}: {[base]basecommand} --db sqlite # -E : ignore PYTHON* environment variables (such as PYTHONPATH) # -s : don't add user site directory to sys.path; also PYTHONNOUSERSITE diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 71d8fa3e56..87e48d3f29 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -169,6 +169,24 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "VARCHAR(1), CHECK (somecolumn IN ('x', " "'y', 'z')))") + def test_create_type_schema_translate(self): + e1 = Enum('x', 'y', 'z', name='somename') + e2 = Enum('x', 'y', 'z', name='somename', schema='someschema') + schema_translate_map = {None: "foo", "someschema": "bar"} + + self.assert_compile( + postgresql.CreateEnumType(e1), + "CREATE TYPE foo.somename AS ENUM ('x', 'y', 'z')", + schema_translate_map=schema_translate_map + ) + + self.assert_compile( + postgresql.CreateEnumType(e2), + "CREATE TYPE bar.somename AS ENUM ('x', 'y', 'z')", + schema_translate_map=schema_translate_map + ) + + def test_create_table_with_tablespace(self): m = MetaData() tbl = Table( diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index fbb1878dc6..5ea5d3515f 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -21,6 +21,8 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing.mock import Mock, call, patch from contextlib import contextmanager from sqlalchemy.util import nested +from sqlalchemy.testing.assertsql import CompiledSQL + users, metadata, users_autoinc = None, None, None @@ -805,6 +807,40 @@ class CompiledCacheTest(fixtures.TestBase): eq_(compile_mock.call_count, 1) eq_(len(cache), 1) + @testing.requires.schemas + @testing.provide_metadata + def test_schema_translate_in_key(self): + Table( + 'x', self.metadata, Column('q', Integer)) + Table( + 'x', self.metadata, Column('q', Integer), + schema=config.test_schema) + self.metadata.create_all() + + m = MetaData() + t1 = Table('x', m, Column('q', Integer)) + ins = t1.insert() + stmt = select([t1.c.q]) + + cache = {} + with config.db.connect().execution_options( + compiled_cache=cache, + ) as conn: + conn.execute(ins, {"q": 1}) + eq_(conn.scalar(stmt), 1) + + with config.db.connect().execution_options( + compiled_cache=cache, + schema_translate_map={None: config.test_schema} + ) as conn: + conn.execute(ins, {"q": 2}) + eq_(conn.scalar(stmt), 2) + + with config.db.connect().execution_options( + compiled_cache=cache, + ) as conn: + eq_(conn.scalar(stmt), 1) + class MockStrategyTest(fixtures.TestBase): @@ -989,6 +1025,156 @@ class ResultProxyTest(fixtures.TestBase): finally: r.close() +class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): + __requires__ = 'schemas', + __backend__ = True + + def test_create_table(self): + map_ = { + None: config.test_schema, + "foo": config.test_schema, "bar": None} + + metadata = MetaData() + t1 = Table('t1', metadata, Column('x', Integer)) + t2 = Table('t2', metadata, Column('x', Integer), schema="foo") + t3 = Table('t3', metadata, Column('x', Integer), schema="bar") + + with self.sql_execution_asserter(config.db) as asserter: + with config.db.connect().execution_options( + schema_translate_map=map_) as conn: + + t1.create(conn) + t2.create(conn) + t3.create(conn) + + t3.drop(conn) + t2.drop(conn) + t1.drop(conn) + + asserter.assert_( + CompiledSQL("CREATE TABLE %s.t1 (x INTEGER)" % config.test_schema), + CompiledSQL("CREATE TABLE %s.t2 (x INTEGER)" % config.test_schema), + CompiledSQL("CREATE TABLE t3 (x INTEGER)"), + CompiledSQL("DROP TABLE t3"), + CompiledSQL("DROP TABLE %s.t2" % config.test_schema), + CompiledSQL("DROP TABLE %s.t1" % config.test_schema) + ) + + def _fixture(self): + metadata = self.metadata + Table( + 't1', metadata, Column('x', Integer), + schema=config.test_schema) + Table( + 't2', metadata, Column('x', Integer), + schema=config.test_schema) + Table('t3', metadata, Column('x', Integer), schema=None) + metadata.create_all() + + def test_ddl_hastable(self): + + map_ = { + None: config.test_schema, + "foo": config.test_schema, "bar": None} + + metadata = MetaData() + Table('t1', metadata, Column('x', Integer)) + Table('t2', metadata, Column('x', Integer), schema="foo") + Table('t3', metadata, Column('x', Integer), schema="bar") + + with config.db.connect().execution_options( + schema_translate_map=map_) as conn: + metadata.create_all(conn) + + assert config.db.has_table('t1', schema=config.test_schema) + assert config.db.has_table('t2', schema=config.test_schema) + assert config.db.has_table('t3', schema=None) + + with config.db.connect().execution_options( + schema_translate_map=map_) as conn: + metadata.drop_all(conn) + + assert not config.db.has_table('t1', schema=config.test_schema) + assert not config.db.has_table('t2', schema=config.test_schema) + assert not config.db.has_table('t3', schema=None) + + @testing.provide_metadata + def test_crud(self): + self._fixture() + + map_ = { + None: config.test_schema, + "foo": config.test_schema, "bar": None} + + metadata = MetaData() + t1 = Table('t1', metadata, Column('x', Integer)) + t2 = Table('t2', metadata, Column('x', Integer), schema="foo") + t3 = Table('t3', metadata, Column('x', Integer), schema="bar") + + with self.sql_execution_asserter(config.db) as asserter: + with config.db.connect().execution_options( + schema_translate_map=map_) as conn: + + conn.execute(t1.insert(), {'x': 1}) + conn.execute(t2.insert(), {'x': 1}) + conn.execute(t3.insert(), {'x': 1}) + + conn.execute(t1.update().values(x=1).where(t1.c.x == 1)) + conn.execute(t2.update().values(x=2).where(t2.c.x == 1)) + conn.execute(t3.update().values(x=3).where(t3.c.x == 1)) + + eq_(conn.scalar(select([t1.c.x])), 1) + eq_(conn.scalar(select([t2.c.x])), 2) + eq_(conn.scalar(select([t3.c.x])), 3) + + conn.execute(t1.delete()) + conn.execute(t2.delete()) + conn.execute(t3.delete()) + + asserter.assert_( + CompiledSQL( + "INSERT INTO %s.t1 (x) VALUES (:x)" % config.test_schema), + CompiledSQL( + "INSERT INTO %s.t2 (x) VALUES (:x)" % config.test_schema), + CompiledSQL( + "INSERT INTO t3 (x) VALUES (:x)"), + CompiledSQL( + "UPDATE %s.t1 SET x=:x WHERE %s.t1.x = :x_1" % ( + config.test_schema, config.test_schema)), + CompiledSQL( + "UPDATE %s.t2 SET x=:x WHERE %s.t2.x = :x_1" % ( + config.test_schema, config.test_schema)), + CompiledSQL("UPDATE t3 SET x=:x WHERE t3.x = :x_1"), + CompiledSQL("SELECT %s.t1.x FROM %s.t1" % ( + config.test_schema, config.test_schema)), + CompiledSQL("SELECT %s.t2.x FROM %s.t2" % ( + config.test_schema, config.test_schema)), + CompiledSQL("SELECT t3.x FROM t3"), + CompiledSQL("DELETE FROM %s.t1" % config.test_schema), + CompiledSQL("DELETE FROM %s.t2" % config.test_schema), + CompiledSQL("DELETE FROM t3") + ) + + @testing.provide_metadata + def test_via_engine(self): + self._fixture() + + map_ = { + None: config.test_schema, + "foo": config.test_schema, "bar": None} + + metadata = MetaData() + t2 = Table('t2', metadata, Column('x', Integer), schema="foo") + + with self.sql_execution_asserter(config.db) as asserter: + eng = config.db.execution_options(schema_translate_map=map_) + conn = eng.connect() + conn.execute(select([t2.c.x])) + asserter.assert_( + CompiledSQL("SELECT %s.t2.x FROM %s.t2" % ( + config.test_schema, config.test_schema)), + ) + class ExecutionOptionsTest(fixtures.TestBase): diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index b7bf87d63d..f9799fda0f 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -1,16 +1,15 @@ -import operator - import unicodedata import sqlalchemy as sa -from sqlalchemy import schema, events, event, inspect +from sqlalchemy import schema, inspect from sqlalchemy import MetaData, Integer, String -from sqlalchemy.testing import (ComparesTables, engines, AssertsCompiledSQL, +from sqlalchemy.testing import ( + ComparesTables, engines, AssertsCompiledSQL, fixtures, skip) from sqlalchemy.testing.schema import Table, Column from sqlalchemy.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import testing from sqlalchemy.util import ue - +from sqlalchemy.testing import config metadata, users = None, None @@ -1345,6 +1344,18 @@ class SchemaTest(fixtures.TestBase): metadata.drop_all() @testing.requires.schemas + @testing.provide_metadata + def test_schema_translation(self): + Table('foob', self.metadata, Column('q', Integer), schema=config.test_schema) + self.metadata.create_all() + + m = MetaData() + map_ = {"foob": config.test_schema} + with config.db.connect().execution_options(schema_translate_map=map_) as conn: + t = Table('foob', m, schema="foob", autoload_with=conn) + eq_(t.schema, "foob") + eq_(t.c.keys(), ['q']) + @testing.requires.schemas @testing.fails_on('sybase', 'FIXME: unknown') def test_explicit_default_schema_metadata(self): engine = testing.db diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index ffd13309b4..5d082175a4 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -2955,6 +2955,57 @@ class DDLTest(fixtures.TestBase, AssertsCompiledSQL): "CREATE TABLE t1 (q INTEGER, CHECK (a = 1))" ) + def test_schema_translate_map_table(self): + m = MetaData() + t1 = Table('t1', m, Column('q', Integer)) + t2 = Table('t2', m, Column('q', Integer), schema='foo') + t3 = Table('t3', m, Column('q', Integer), schema='bar') + + schema_translate_map = {None: "z", "bar": None, "foo": "bat"} + + self.assert_compile( + schema.CreateTable(t1), + "CREATE TABLE z.t1 (q INTEGER)", + schema_translate_map=schema_translate_map + ) + + self.assert_compile( + schema.CreateTable(t2), + "CREATE TABLE bat.t2 (q INTEGER)", + schema_translate_map=schema_translate_map + ) + + self.assert_compile( + schema.CreateTable(t3), + "CREATE TABLE t3 (q INTEGER)", + schema_translate_map=schema_translate_map + ) + + def test_schema_translate_map_sequence(self): + s1 = schema.Sequence('s1') + s2 = schema.Sequence('s2', schema='foo') + s3 = schema.Sequence('s3', schema='bar') + + schema_translate_map = {None: "z", "bar": None, "foo": "bat"} + + self.assert_compile( + schema.CreateSequence(s1), + "CREATE SEQUENCE z.s1", + schema_translate_map=schema_translate_map + ) + + self.assert_compile( + schema.CreateSequence(s2), + "CREATE SEQUENCE bat.s2", + schema_translate_map=schema_translate_map + ) + + self.assert_compile( + schema.CreateSequence(s3), + "CREATE SEQUENCE s3", + schema_translate_map=schema_translate_map + ) + class InlineDefaultTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' @@ -3049,6 +3100,82 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): ' "dbo.remote_owner".remotetable' ) + def test_schema_translate_select(self): + schema_translate_map = {"remote_owner": "foob", None: 'bar'} + + self.assert_compile( + table1.select().where(table1.c.name == 'hi'), + "SELECT bar.mytable.myid, bar.mytable.name, " + "bar.mytable.description FROM bar.mytable " + "WHERE bar.mytable.name = :name_1", + schema_translate_map=schema_translate_map + ) + + self.assert_compile( + table4.select().where(table4.c.value == 'hi'), + "SELECT foob.remotetable.rem_id, foob.remotetable.datatype_id, " + "foob.remotetable.value FROM foob.remotetable " + "WHERE foob.remotetable.value = :value_1", + schema_translate_map=schema_translate_map + ) + + schema_translate_map = {"remote_owner": "foob"} + self.assert_compile( + select([ + table1, table4 + ]).select_from( + join(table1, table4, table1.c.myid == table4.c.rem_id) + ), + "SELECT mytable.myid, mytable.name, mytable.description, " + "foob.remotetable.rem_id, foob.remotetable.datatype_id, " + "foob.remotetable.value FROM mytable JOIN foob.remotetable " + "ON foob.remotetable.rem_id = mytable.myid", + schema_translate_map=schema_translate_map + ) + + def test_schema_translate_crud(self): + schema_translate_map = {"remote_owner": "foob", None: 'bar'} + + self.assert_compile( + table1.insert().values(description='foo'), + "INSERT INTO bar.mytable (description) VALUES (:description)", + schema_translate_map=schema_translate_map + ) + + self.assert_compile( + table1.update().where(table1.c.name == 'hi'). + values(description='foo'), + "UPDATE bar.mytable SET description=:description " + "WHERE bar.mytable.name = :name_1", + schema_translate_map=schema_translate_map + ) + self.assert_compile( + table1.delete().where(table1.c.name == 'hi'), + "DELETE FROM bar.mytable WHERE bar.mytable.name = :name_1", + schema_translate_map=schema_translate_map + ) + + self.assert_compile( + table4.insert().values(value='there'), + "INSERT INTO foob.remotetable (value) VALUES (:value)", + schema_translate_map=schema_translate_map + ) + + self.assert_compile( + table4.update().where(table4.c.value == 'hi'). + values(value='there'), + "UPDATE foob.remotetable SET value=:value " + "WHERE foob.remotetable.value = :value_1", + schema_translate_map=schema_translate_map + ) + + self.assert_compile( + table4.delete().where(table4.c.value == 'hi'), + "DELETE FROM foob.remotetable WHERE " + "foob.remotetable.value = :value_1", + schema_translate_map=schema_translate_map + ) + def test_alias(self): a = alias(table4, 'remtable') self.assert_compile(a.select(a.c.datatype_id == 7),