From: Michael Trier Date: Sat, 11 Apr 2009 21:36:45 +0000 (+0000) Subject: Added multi part schema name support. Closes #594 and #1341. X-Git-Tag: rel_0_5_4~26 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=2a962802de28615f5c961b423e1a995b7bd691bc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Added multi part schema name support. Closes #594 and #1341. --- diff --git a/CHANGES b/CHANGES index 2a82a58e34..6a5189399d 100644 --- a/CHANGES +++ b/CHANGES @@ -40,7 +40,13 @@ CHANGES key attribute on an item contained within a collection owned by an object being deleted would not be set to None if the relation() was self-referential. [ticket:1376] - + +- schema + - Added a quote_schema() method to the IdentifierPreparer class + so that dialects can override how schemas get handled. This + enables the MSSQL dialect to treat schemas as multipart + identifiers, such as 'database.owner'. [ticket: 594, 1341] + - sql - ``sqlalchemy.extract()`` is now dialect sensitive and can extract components of timestamps idiomatically across the diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 03cf73eee3..396e8dd241 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -1733,6 +1733,11 @@ class MSSQLIdentifierPreparer(compiler.IdentifierPreparer): #TODO: determine MSSQL's escaping rules return value + def quote_schema(self, schema, force=True): + """Prepare a quoted table and schema name.""" + result = '.'.join([self.quote(x, force) for x in schema.split('.')]) + return result + dialect = MSSQLDialect dialect.statement_compiler = MSSQLCompiler dialect.schemagenerator = MSSQLSchemaGenerator diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index fd99d2de70..47c01024c7 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -881,17 +881,33 @@ class ForeignKey(SchemaItem): raise exc.ArgumentError( "Parent column '%s' does not descend from a " "table-attached Column" % str(self.parent)) - m = re.match(r"^(.+?)(?:\.(.+?))?(?:\.(.+?))?$", self._colspec, - re.UNICODE) + + m = self._colspec.split('.') + if m is None: raise exc.ArgumentError( "Invalid foreign key column specification: %s" % self._colspec) - if m.group(3) is None: - (tname, colname) = m.group(1, 2) - schema = None + + # A FK between column 'bar' and table 'foo' can be + # specified as 'foo', 'foo.bar', 'dbo.foo.bar', + # 'otherdb.dbo.foo.bar'. Once we have the column name and + # the table name, treat everything else as the schema + # name. Some databases (e.g. Sybase) support + # inter-database foreign keys. See tickets#1341 and -- + # indirectly related -- Ticket #594. This assumes that '.' + # will never appear *within* any component of the FK. + + (schema, tname, colname) = (None, None, None) + if (len(m) == 1): + tname = m.pop() else: - (schema, tname, colname) = m.group(1, 2, 3) + colname = m.pop() + tname = m.pop() + + if (len(m) > 0): + schema = '.'.join(m) + if _get_table_key(tname, schema) not in parenttable.metadata: raise exc.NoReferencedTableError( "Could not find table '%s' with which to generate a " diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5042959b25..84b0ff6283 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -292,7 +292,7 @@ class DefaultCompiler(engine.Compiled): return name else: if column.table.schema: - schema_prefix = self.preparer.quote(column.table.schema, column.table.quote_schema) + '.' + schema_prefix = self.preparer.quote_schema(column.table.schema, column.table.quote_schema) + '.' else: schema_prefix = '' tablename = column.table.name @@ -613,7 +613,7 @@ class DefaultCompiler(engine.Compiled): def visit_table(self, table, asfrom=False, **kwargs): if asfrom: if getattr(table, "schema", None): - return self.preparer.quote(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote) + return self.preparer.quote_schema(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote) else: return self.preparer.quote(table.name, table.quote) else: @@ -1094,7 +1094,15 @@ class IdentifierPreparer(object): or self.illegal_initial_characters.match(value[0]) or not self.legal_characters.match(unicode(value)) or (lc_value != value)) - + + def quote_schema(self, schema, force): + """Quote a schema. + + Subclasses should override this to provide database-dependent + quoting behavior. + """ + return self.quote(schema, force) + def quote(self, ident, force): if force is None: if ident in self._strings: @@ -1113,7 +1121,7 @@ class IdentifierPreparer(object): def format_sequence(self, sequence, use_schema=True): name = self.quote(sequence.name, sequence.quote) if not self.omit_schema and use_schema and sequence.schema is not None: - name = self.quote(sequence.schema, sequence.quote) + "." + name + name = self.quote_schema(sequence.schema, sequence.quote) + "." + name return name def format_label(self, label, name=None): @@ -1135,7 +1143,7 @@ class IdentifierPreparer(object): name = table.name result = self.quote(name, table.quote) if not self.omit_schema and use_schema and getattr(table, "schema", None): - result = self.quote(table.schema, table.quote_schema) + "." + result + result = self.quote_schema(table.schema, table.quote_schema) + "." + result return result def format_column(self, column, use_table=False, name=None, table_name=None): @@ -1163,7 +1171,7 @@ class IdentifierPreparer(object): # a longer sequence. if not self.omit_schema and use_schema and getattr(table, 'schema', None): - return (self.quote(table.schema, table.quote_schema), + return (self.quote_schema(table.schema, table.quote_schema), self.format_table(table, use_schema=False)) else: return (self.format_table(table, use_schema=False), ) diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py index de9c5cd62b..50f9594ef3 100755 --- a/test/dialect/mssql.py +++ b/test/dialect/mssql.py @@ -88,6 +88,30 @@ class CompileTest(TestBase, AssertsCompiledSQL): s = select([tbl.c.id]).where(tbl.c.id==1) self.assert_compile(tbl.delete().where(tbl.c.id==(s)), "DELETE FROM paj.test WHERE paj.test.id IN (SELECT test_1.id FROM paj.test AS test_1 WHERE test_1.id = :id_1)") + def test_delete_schema_multipart(self): + metadata = MetaData() + tbl = Table('test', metadata, Column('id', Integer, primary_key=True), schema='banana.paj') + self.assert_compile(tbl.delete(tbl.c.id == 1), "DELETE FROM banana.paj.test WHERE banana.paj.test.id = :id_1") + + s = select([tbl.c.id]).where(tbl.c.id==1) + self.assert_compile(tbl.delete().where(tbl.c.id==(s)), "DELETE FROM banana.paj.test WHERE banana.paj.test.id IN (SELECT test_1.id FROM banana.paj.test AS test_1 WHERE test_1.id = :id_1)") + + def test_delete_schema_multipart_needs_quoting(self): + metadata = MetaData() + tbl = Table('test', metadata, Column('id', Integer, primary_key=True), schema='banana split.paj') + self.assert_compile(tbl.delete(tbl.c.id == 1), "DELETE FROM [banana split].paj.test WHERE [banana split].paj.test.id = :id_1") + + s = select([tbl.c.id]).where(tbl.c.id==1) + self.assert_compile(tbl.delete().where(tbl.c.id==(s)), "DELETE FROM [banana split].paj.test WHERE [banana split].paj.test.id IN (SELECT test_1.id FROM [banana split].paj.test AS test_1 WHERE test_1.id = :id_1)") + + def test_delete_schema_multipart_both_need_quoting(self): + metadata = MetaData() + tbl = Table('test', metadata, Column('id', Integer, primary_key=True), schema='banana split.paj with a space') + self.assert_compile(tbl.delete(tbl.c.id == 1), "DELETE FROM [banana split].[paj with a space].test WHERE [banana split].[paj with a space].test.id = :id_1") + + s = select([tbl.c.id]).where(tbl.c.id==1) + self.assert_compile(tbl.delete().where(tbl.c.id==(s)), "DELETE FROM [banana split].[paj with a space].test WHERE [banana split].[paj with a space].test.id IN (SELECT test_1.id FROM [banana split].[paj with a space].test AS test_1 WHERE test_1.id = :id_1)") + def test_union(self): t1 = table('t1', column('col1'),