]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Corrected the "has_sequence" query to take current schema,
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 21 Oct 2009 04:47:02 +0000 (04:47 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 21 Oct 2009 04:47:02 +0000 (04:47 +0000)
or explicit sequence-stated schema, into account.
[ticket:1576]

CHANGES
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/ddl.py
test/engine/test_reflection.py

diff --git a/CHANGES b/CHANGES
index f83ea987a705222fbb705b36f0b194ab76b6fed8..480350fdf2065f02077489bc203c845e72ac684b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -597,6 +597,10 @@ CHANGES
       This is postgresql.DOUBLE_PRECISION in 0.6.
       [ticket:1085]
 
+    - Corrected the "has_sequence" query to take current schema,
+      or explicit sequence-stated schema, into account.
+      [ticket:1576]
+
 - mssql
     - Changed the name of TrustedConnection to
       Trusted_Connection when constructing pyodbc connect
index 4c2c7f913f26bf941fe7494810ce7a2c49bd46ce..232583c39eb5ca23cccfa7570d015f2bc5452655 100644 (file)
@@ -378,7 +378,7 @@ class FBDialect(default.DefaultDialect):
         c = connection.execute(tblqry, [self.denormalize_name(table_name)])
         return c.first() is not None
 
-    def has_sequence(self, connection, sequence_name):
+    def has_sequence(self, connection, sequence_name, schema=None):
         """Return ``True`` if the given sequence (generator) exists."""
 
         genqry = """
index 0bc5f08b028457b78baa0fa8f055d43b8b60b877..061c3b066eb741a2f4f8dbe1d6d637c32f009e35 100644 (file)
@@ -539,13 +539,23 @@ class PGDialect(default.DefaultDialect):
             )
         return bool(cursor.first())
 
-    def has_sequence(self, connection, sequence_name):
-        cursor = connection.execute(
-                    sql.text("SELECT relname FROM pg_class WHERE relkind = 'S' AND "
-                        "relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%' "
-                        "AND nspname != 'information_schema' AND relname = :seqname)", 
-                        bindparams=[sql.bindparam('seqname', unicode(sequence_name), type_=sqltypes.Unicode)]
-                    ))
+    def has_sequence(self, connection, sequence_name, schema=None):
+        if schema is None:
+            cursor = connection.execute(
+                        sql.text("SELECT relname FROM pg_class c join pg_namespace n on "
+                            "n.oid=c.relnamespace where relkind='S' and n.nspname=current_schema() and lower(relname)=:name",
+                            bindparams=[sql.bindparam('name', unicode(sequence_name.lower()), type_=sqltypes.Unicode)] 
+                        )
+                    )
+        else:
+            cursor = connection.execute(
+                        sql.text("SELECT relname FROM pg_class c join pg_namespace n on "
+                            "n.oid=c.relnamespace where relkind='S' and n.nspname=:schema and lower(relname)=:name",
+                            bindparams=[sql.bindparam('name', unicode(sequence_name.lower()), type_=sqltypes.Unicode),
+                                sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)] 
+                        )
+                    )
+
         return bool(cursor.first())
 
     def table_names(self, connection, schema):
index 6e7253e9a700a653fedcd3100d69a7899c5c90a1..f5cff1a2045be34a8cbd6d6c950008a3a6a00c44 100644 (file)
@@ -66,7 +66,7 @@ class SchemaGenerator(DDLBase):
             if ((not self.dialect.sequences_optional or
                  not sequence.optional) and
                 (not self.checkfirst or
-                 not self.dialect.has_sequence(self.connection, sequence.name))):
+                 not self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema))):
                 self.connection.execute(schema.CreateSequence(sequence))
 
     def visit_index(self, index):
@@ -124,5 +124,5 @@ class SchemaDropper(DDLBase):
             if ((not self.dialect.sequences_optional or
                  not sequence.optional) and
                 (not self.checkfirst or
-                 self.dialect.has_sequence(self.connection, sequence.name))):
+                 self.dialect.has_sequence(self.connection, sequence.name, schema=sequence.schema))):
                 self.connection.execute(schema.DropSequence(sequence))
index 48562f6fc7814d37ddf81341b3c2ee37aec810e2..0bf5a9259d58512e6b8f7dd21cefd1be14aa798b 100644 (file)
@@ -869,22 +869,38 @@ class SchemaTest(TestBase):
 
 
 class HasSequenceTest(TestBase):
-    @classmethod
-    def setup_class(cls):
-        global metadata, users
+
+    @testing.requires.sequences
+    def test_has_sequence(self):
         metadata = MetaData()
         users = Table('users', metadata,
                       Column('user_id', sa.Integer, sa.Sequence('user_id_seq'), primary_key=True),
                       Column('user_name', sa.String(40)),
                       )
+        metadata.create_all(bind=testing.db)
+        try:
+            eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), True)
+        finally:
+            metadata.drop_all(bind=testing.db)
+        eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False)
 
+    @testing.requires.schemas
     @testing.requires.sequences
-    def test_hassequence(self):
-        metadata.create_all(bind=testing.db)
+    def test_has_sequence_schema(self):
+        test_schema = get_schema()
+        s1 = sa.Sequence('user_id_seq', schema=test_schema)
+        s2 = sa.Sequence('user_id_seq')
+        testing.db.execute(schema.CreateSequence(s1))
+        testing.db.execute(schema.CreateSequence(s2))
+        eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', schema=test_schema), True)
         eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), True)
-        metadata.drop_all(bind=testing.db)
+        testing.db.execute(schema.DropSequence(s1))
+        eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', schema=test_schema), False)
+        eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), True)
+        testing.db.execute(schema.DropSequence(s2))
+        eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', schema=test_schema), False)
         eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False)
-
+        
 # Tests related to engine.reflection
 
 def get_schema():